Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/cell_wrappers.py: 26%

229 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Module implementing RNN wrappers.""" 

16 

17 

18# Note that all the APIs under this module are exported as tf.nn.*. This is due 

19# to the fact that those APIs were from tf.nn.rnn_cell_impl. They are ported 

20# here to avoid the cyclic dependency issue for serialization. These APIs will 

21# probably be deprecated and removed in future since similar API is available in 

22# existing Keras RNN API. 

23 

24import hashlib 

25import numbers 

26import sys 

27import types as python_types 

28import warnings 

29 

30import tensorflow.compat.v2 as tf 

31 

32from keras.src.layers.rnn import lstm 

33from keras.src.layers.rnn.abstract_rnn_cell import AbstractRNNCell 

34from keras.src.saving import serialization_lib 

35from keras.src.utils import generic_utils 

36from keras.src.utils import tf_inspect 

37 

38# isort: off 

39from tensorflow.python.util.tf_export import tf_export 

40from tensorflow.python.util.deprecation import deprecated 

41 

42 

class _RNNCellWrapper(AbstractRNNCell):
    """V2-compatible base class for RNN cell wrappers.

    Together with `rnn_cell_impl._RNNCellWrapperV1` (defined outside this
    file) it lets a wrapper be written once for V1 and V2. Subclasses
    implement `_call_wrapped_cell`; everything else (`build`, `state_size`,
    `output_size`, `zero_state` and config handling) is forwarded to the
    wrapped cell stored in `self.cell`.
    """

    def __init__(self, cell, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cell = cell
        # The wrapper advertises a `training` argument whenever the wrapped
        # cell's `call` names it explicitly or accepts arbitrary keywords.
        spec = tf_inspect.getfullargspec(cell.call)
        names_training = "training" in spec.args
        takes_var_keywords = spec.varkw is not None
        self._call_spec.expects_training_arg = (
            names_training or takes_var_keywords
        )

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Invokes the wrapped cell and applies the wrapper's own logic.

        Called by the wrapper's `call`; subclasses must override it.

        Args:
          inputs: A tensor with wrapped cell's input.
          state: A tensor or tuple of tensors with wrapped cell's state.
          cell_call_fn: Wrapped cell's method to use for step computation
            (cell's `__call__` or 'call' method).
          **kwargs: Additional arguments.

        Returns:
          A pair containing:
          - Output: A tensor with cell's output.
          - New state: A tensor or tuple of tensors with new wrapped cell's
            state.

        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def call(self, inputs, state, **kwargs):
        """Runs one step of the RNN cell through the wrapper.

        The wrapped cell's `call` (not `__call__`) is handed to
        `_call_wrapped_cell`, on the assumption that the wrapper, and hence
        the wrapped cell, has already been built via `build`.

        Args:
          inputs: A tensor with wrapped cell's input.
          state: A tensor or tuple of tensors with wrapped cell's state.
          **kwargs: Additional arguments passed to the wrapped cell's `call`.

        Returns:
          A pair containing:

          - Output: A tensor with cell's output.
          - New state: A tensor or tuple of tensors with new wrapped cell's
            state.
        """
        step_fn = self.cell.call
        return self._call_wrapped_cell(
            inputs, state, cell_call_fn=step_fn, **kwargs
        )

    def build(self, inputs_shape):
        """Builds the wrapped cell and marks the wrapper as built."""
        self.cell.build(inputs_shape)
        self.built = True

    @property
    def wrapped_cell(self):
        return self.cell

    @property
    def state_size(self):
        return self.cell.state_size

    @property
    def output_size(self):
        return self.cell.output_size

    def zero_state(self, batch_size, dtype):
        """Returns the wrapped cell's zero state, created in a name scope."""
        scope_name = type(self).__name__ + "ZeroState"
        with tf.name_scope(scope_name):
            return self.cell.zero_state(batch_size, dtype)

    def get_config(self):
        """Returns the base config extended with the wrapped cell's config."""
        wrapped = self.cell
        cell_entry = {
            "class_name": wrapped.__class__.__name__,
            "config": wrapped.get_config(),
        }
        return {**super().get_config(), "cell": cell_entry}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuilds a wrapper, deserializing the nested `"cell"` entry."""
        remaining = config.copy()
        # Function-scope import, presumably to avoid a circular import
        # with the layer serialization module.
        from keras.src.layers.serialization import (
            deserialize as deserialize_layer,
        )

        cell = deserialize_layer(
            remaining.pop("cell"), custom_objects=custom_objects
        )
        return cls(cell, **remaining)

145 

146 

@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(_RNNCellWrapper):
    """Operator adding dropout to inputs and outputs of the given cell."""

    def __init__(
        self,
        cell,
        input_keep_prob=1.0,
        output_keep_prob=1.0,
        state_keep_prob=1.0,
        variational_recurrent=False,
        input_size=None,
        dtype=None,
        seed=None,
        dropout_state_filter_visitor=None,
        **kwargs,
    ):
        """Create a cell with added input, state, and/or output dropout.

        If `variational_recurrent` is set to `True` (**NOT** the default
        behavior), the dropout noise is generated once here and reused by
        every call, as described in: [A Theoretically Grounded Application
        of Dropout in Recurrent Neural Networks. Y. Gal, Z.
        Ghahramani](https://arxiv.org/abs/1512.05287).

        Otherwise dropout is applied through `tf.nn.dropout` on every call.

        Which parts of the state receive dropout is decided by
        `dropout_state_filter_visitor`. The default used in this module,
        `_default_dropout_state_filter_visitor`, excludes only
        `tf.TensorArray` objects.

        Args:
          cell: the RNN cell to wrap. A Keras `LSTMCell` is rejected.
          input_keep_prob: unit Tensor or float between 0 and 1, input keep
            probability; if it is constant and 1, no input dropout will be
            added.
          output_keep_prob: unit Tensor or float between 0 and 1, output keep
            probability; if it is constant and 1, no output dropout will be
            added.
          state_keep_prob: unit Tensor or float between 0 and 1, state keep
            probability; if it is constant and 1, no state dropout will be
            added. State dropout is performed on the new state returned by
            the cell, restricted to the components selected by
            `dropout_state_filter_visitor`.
          variational_recurrent: Python bool. If `True`, precomputed noise
            tensors are used instead of `tf.nn.dropout`. `dtype` **must**
            then be provided.
          input_size: (optional) (possibly nested tuple of) `TensorShape`
            objects containing the depth(s) of the input tensors expected to
            be passed in to the `DropoutWrapper`. Required when
            `variational_recurrent = True` and `input_keep_prob` is below 1
            or not a statically known number.
          dtype: (optional) The `dtype` of the input, state, and output
            tensors. Forwarded to the base layer and required when
            `variational_recurrent = True`.
          seed: (optional) integer, the randomness seed.
          dropout_state_filter_visitor: (optional) callable that takes any
            hierarchical level of the state and returns a scalar or depth=1
            structure of Python booleans describing which terms in the
            state should be dropped out. Falls back to
            `_default_dropout_state_filter_visitor` when not given.
          **kwargs: dict of keyword arguments for base layer.

        Raises:
          TypeError: if `dropout_state_filter_visitor` is provided but not
            callable.
          ValueError: if `cell` is a Keras `LSTMCell`, if any statically
            known keep probability is not between 0 and 1, or if
            `variational_recurrent` is set without the required `dtype` /
            `input_size`.
        """
        if isinstance(cell, lstm.LSTMCell):
            raise ValueError(
                "keras LSTM cell does not work with DropoutWrapper. "
                "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
                "instead."
            )
        super().__init__(cell, dtype=dtype, **kwargs)

        filter_is_invalid = (
            dropout_state_filter_visitor is not None
            and not callable(dropout_state_filter_visitor)
        )
        if filter_is_invalid:
            raise TypeError(
                "dropout_state_filter_visitor must be callable. "
                f"Received: {dropout_state_filter_visitor}"
            )
        self._dropout_state_filter = (
            dropout_state_filter_visitor
            or _default_dropout_state_filter_visitor
        )

        with tf.name_scope("DropoutWrapperInit"):
            keep_probs = (
                ("input_keep_prob", input_keep_prob),
                ("state_keep_prob", state_keep_prob),
                ("output_keep_prob", output_keep_prob),
            )
            for attr, prob in keep_probs:
                as_tensor = tf.convert_to_tensor(prob)
                static_value = tf.get_static_value(as_tensor)
                if static_value is None:
                    # Not known statically: keep the tensor itself.
                    setattr(self, f"_{attr}", as_tensor)
                    continue
                if static_value < 0 or static_value > 1:
                    raise ValueError(
                        f"Parameter {attr} must be between 0 and 1. "
                        f"Received {static_value}"
                    )
                # Statically known values are stored as Python floats so
                # later checks can tell them apart from tensors.
                setattr(self, f"_{attr}", float(static_value))

        # These must be set before the noise tensors are generated below,
        # because `_gen_seed` reads `self._seed`.
        self._variational_recurrent = variational_recurrent
        self._input_size = input_size
        self._seed = seed

        self._recurrent_input_noise = None
        self._recurrent_state_noise = None
        self._recurrent_output_noise = None

        if variational_recurrent:
            if dtype is None:
                raise ValueError(
                    "When variational_recurrent=True, dtype must be provided"
                )

            def batch_noise(s, inner_seed):
                # Prepend a 1 for the batch dimension; for recurrent
                # variational dropout we use the same dropout mask for all
                # batch elements.
                shape = tf.concat(([1], tf.TensorShape(s).as_list()), 0)
                return tf.random.uniform(shape, seed=inner_seed, dtype=dtype)

            def noise_for(structure, salt_prefix):
                # One noise tensor per leaf, each with its own derived seed.
                return _enumerated_map_structure_up_to(
                    structure,
                    lambda i, s: batch_noise(
                        s, inner_seed=self._gen_seed(salt_prefix, i)
                    ),
                    structure,
                )

            input_prob = self._input_keep_prob
            if not isinstance(input_prob, numbers.Real) or input_prob < 1.0:
                if input_size is None:
                    raise ValueError(
                        "When variational_recurrent=True and input_keep_prob < "
                        "1.0 or is unknown, input_size must be provided"
                    )
                self._recurrent_input_noise = noise_for(input_size, "input")
            self._recurrent_state_noise = noise_for(cell.state_size, "state")
            self._recurrent_output_noise = noise_for(
                cell.output_size, "output"
            )

    def _gen_seed(self, salt_prefix, index):
        """Derives a 31-bit seed from `self._seed`, a prefix and an index.

        Returns `None` when no seed was configured.
        """
        if self._seed is None:
            return None
        salt = "%s_%d" % (salt_prefix, index)
        payload = (str(self._seed) + salt).encode("utf-8")
        digest = hashlib.md5(payload).hexdigest()
        return int(digest[:8], 16) & 0x7FFFFFFF

    def _variational_recurrent_dropout_value(
        self, unused_index, value, noise, keep_prob
    ):
        """Performs dropout given the pre-calculated noise tensor."""
        # `keep_prob + noise` floors to 1 where the element is kept and to 0
        # where it is dropped (noise comes from `tf.random.uniform`).
        keep_mask = tf.floor(keep_prob + noise)
        dropped = tf.divide(value, keep_prob) * keep_mask
        dropped.set_shape(value.get_shape())
        return dropped

    def _dropout(
        self,
        values,
        salt_prefix,
        recurrent_noise,
        keep_prob,
        shallow_filtered_substructure=None,
    ):
        """Decides whether to perform standard dropout or recurrent dropout."""
        if shallow_filtered_substructure is None:
            # Traverse the whole structure; its leaves are then not bools,
            # which `wants_dropout` treats as "apply dropout".
            shallow_filtered_substructure = values

        def wants_dropout(flag):
            return not isinstance(flag, bool) or flag

        if self._variational_recurrent:

            def apply(i, do_dropout, v, n):
                if wants_dropout(do_dropout):
                    return self._variational_recurrent_dropout_value(
                        i, v, n, keep_prob
                    )
                return v

            return _enumerated_map_structure_up_to(
                shallow_filtered_substructure,
                apply,
                shallow_filtered_substructure,
                values,
                recurrent_noise,
            )

        def apply(i, do_dropout, v):
            if wants_dropout(do_dropout):
                return tf.nn.dropout(
                    v,
                    rate=1.0 - keep_prob,
                    seed=self._gen_seed(salt_prefix, i),
                )
            return v

        return _enumerated_map_structure_up_to(
            shallow_filtered_substructure,
            apply,
            shallow_filtered_substructure,
            values,
        )

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Runs the wrapped cell and applies dropout.

        Dropout is applied to the inputs before the cell runs, then to the
        new state and the output the cell returns.

        Args:
          inputs: A tensor with wrapped cell's input.
          state: A tensor or tuple of tensors with wrapped cell's state.
          cell_call_fn: Wrapped cell's method to use for step computation
            (cell's `__call__` or 'call' method).
          **kwargs: Additional arguments.

        Returns:
          A pair containing:

          - Output: A tensor with cell's output.
          - New state: A tensor or tuple of tensors with new wrapped cell's
            state.
        """

        def is_active(keep_prob):
            # Anything that is not a Python float (e.g. a tensor) is not
            # known statically, so dropout is always applied for it.
            if isinstance(keep_prob, float):
                return keep_prob < 1
            return True

        if is_active(self._input_keep_prob):
            inputs = self._dropout(
                inputs,
                "input",
                self._recurrent_input_noise,
                self._input_keep_prob,
            )

        output, new_state = cell_call_fn(inputs, state, **kwargs)

        if is_active(self._state_keep_prob):
            # Identify which subsets of the state to perform dropout on and
            # which ones to keep.
            state_filter = tf.__internal__.nest.get_traverse_shallow_structure(
                self._dropout_state_filter, new_state
            )
            new_state = self._dropout(
                new_state,
                "state",
                self._recurrent_state_noise,
                self._state_keep_prob,
                state_filter,
            )

        if is_active(self._output_keep_prob):
            output = self._dropout(
                output,
                "output",
                self._recurrent_output_noise,
                self._output_keep_prob,
            )
        return output, new_state

    def get_config(self):
        """Returns the config of the dropout wrapper."""
        config = {
            "input_keep_prob": self._input_keep_prob,
            "output_keep_prob": self._output_keep_prob,
            "state_keep_prob": self._state_keep_prob,
            "variational_recurrent": self._variational_recurrent,
            "input_size": self._input_size,
            "seed": self._seed,
        }
        # A custom state filter is serialized; the default one is implied.
        if self._dropout_state_filter != _default_dropout_state_filter_visitor:
            fn, fn_type, fn_module = _serialize_function_to_config(
                self._dropout_state_filter
            )
            config["dropout_fn"] = fn
            config["dropout_fn_type"] = fn_type
            config["dropout_fn_module"] = fn_module
        return {**super().get_config(), **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuilds the wrapper, restoring a serialized state filter if any."""
        if "dropout_fn" in config:
            config = config.copy()
            # The helper pops the type and module entries itself.
            visitor = _parse_config_to_function(
                config,
                custom_objects,
                "dropout_fn",
                "dropout_fn_type",
                "dropout_fn_module",
            )
            del config["dropout_fn"]
            config["dropout_state_filter_visitor"] = visitor
        return super(DropoutWrapper, cls).from_config(
            config, custom_objects=custom_objects
        )

492 

493 

@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(_RNNCellWrapper):
    """RNNCell wrapper that ensures cell inputs are added to the outputs."""

    def __init__(self, cell, residual_fn=None, **kwargs):
        """Constructs a `ResidualWrapper` for `cell`.

        Args:
          cell: An instance of `RNNCell`.
          residual_fn: (Optional) The function to map raw cell inputs and raw
            cell outputs to the actual cell outputs of the residual network.
            When omitted, inputs and outputs are added element-wise across
            their (matching) nested structures.
          **kwargs: dict of keyword arguments for base layer.
        """
        super().__init__(cell, **kwargs)
        self._residual_fn = residual_fn

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Run the cell and apply the residual_fn.

        Args:
          inputs: cell inputs.
          state: cell state.
          cell_call_fn: Wrapped cell's method to use for step computation
            (cell's `__call__` or 'call' method).
          **kwargs: Additional arguments passed to the wrapped cell's `call`.

        Returns:
          Tuple of cell outputs and new state.

        Raises:
          TypeError: If cell inputs and outputs have different structure (type).
          ValueError: If cell inputs and outputs have different structure
            (value).
        """
        outputs, new_state = cell_call_fn(inputs, state, **kwargs)

        def check_shapes(inp, out):
            # Ensure shapes match.
            inp.get_shape().assert_is_compatible_with(out.get_shape())

        def add_inputs(cell_inputs, cell_outputs):
            # Validate the structure and every shape before adding anything.
            tf.nest.assert_same_structure(cell_inputs, cell_outputs)
            tf.nest.map_structure(check_shapes, cell_inputs, cell_outputs)
            return tf.nest.map_structure(
                lambda inp, out: inp + out, cell_inputs, cell_outputs
            )

        combine = self._residual_fn or add_inputs
        return (combine(inputs, outputs), new_state)

    def get_config(self):
        """Returns the config of the residual wrapper."""
        extra = {}
        if self._residual_fn is not None:
            fn, fn_type, fn_module = _serialize_function_to_config(
                self._residual_fn
            )
            extra = {
                "residual_fn": fn,
                "residual_fn_type": fn_type,
                "residual_fn_module": fn_module,
            }
        return {**super().get_config(), **extra}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuilds the wrapper, restoring a serialized `residual_fn` if any."""
        if "residual_fn" in config:
            config = config.copy()
            # The helper pops the type and module entries; the function
            # entry is then overwritten with the reconstructed callable.
            config["residual_fn"] = _parse_config_to_function(
                config,
                custom_objects,
                "residual_fn",
                "residual_fn_type",
                "residual_fn_module",
            )
        return super(ResidualWrapper, cls).from_config(
            config, custom_objects=custom_objects
        )

582 

583 

@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(_RNNCellWrapper):
    """Operator that ensures an RNNCell runs on a particular device."""

    def __init__(self, cell, device, **kwargs):
        """Construct a `DeviceWrapper` for `cell` with device `device`.

        The wrapped `cell` is called inside `tf.compat.v1.device(device)`.

        Args:
          cell: An instance of `RNNCell`.
          device: A device string or function, for passing to `tf.device`.
          **kwargs: dict of keyword arguments for base layer.
        """
        super().__init__(cell, **kwargs)
        self._device = device

    def zero_state(self, batch_size, dtype):
        """Returns the wrapped cell's zero state, created on the device."""
        scope_name = type(self).__name__ + "ZeroState"
        with tf.name_scope(scope_name), tf.compat.v1.device(self._device):
            return self.cell.zero_state(batch_size, dtype)

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Run the cell on specified device."""
        device_scope = tf.compat.v1.device(self._device)
        with device_scope:
            return cell_call_fn(inputs, state, **kwargs)

    def get_config(self):
        """Returns the base config extended with the `device` entry."""
        return {**super().get_config(), "device": self._device}

616 

617 

def _serialize_function_to_config(function):
    """Serialize the function for get_config().

    Args:
      function: the callable to serialize.

    Returns:
      A `(payload, kind, module)` tuple. `kind` is `"lambda"` when the
      payload is the result of `generic_utils.func_dump`, and `"function"`
      when the payload is the callable's `__name__`. `module` is the
      callable's `__module__`.

    Raises:
      ValueError: if `function` is not callable.
    """
    # `types.LambdaType` is an alias of `types.FunctionType`, so this branch
    # matches every Python-level function, not only lambda expressions.
    if isinstance(function, python_types.LambdaType):
        return (
            generic_utils.func_dump(function),
            "lambda",
            function.__module__,
        )

    # Other callables (e.g. built-ins) are stored by name.
    if callable(function):
        return function.__name__, "function", function.__module__

    raise ValueError(
        f"Unrecognized function type for input: {type(function)}"
    )

634 

635 

def _parse_config_to_function(
    config,
    custom_objects,
    func_attr_name,
    func_type_attr_name,
    module_attr_name,
):
    """Reconstruct the function from the config.

    Args:
      config: mutable mapping holding the serialized function. The entries
        under `module_attr_name` and `func_type_attr_name` are popped from
        it; the entry under `func_attr_name` is read but left in place.
      custom_objects: optional dict of extra names made available while
        resolving the function.
      func_attr_name: key of the serialized function payload.
      func_type_attr_name: key of the function type, `"function"` or
        `"lambda"`.
      module_attr_name: key of the name of the module the function came
        from (may be absent).

    Returns:
      The object produced by `serialization_lib.deserialize_keras_object`
      (type `"function"`) or `generic_utils.func_load` (type `"lambda"`).

    Raises:
      KeyError: if `func_type_attr_name` is missing from `config`.
      ValueError: if a lambda is requested while safe mode is enabled.
      TypeError: if the function type is neither `"function"` nor
        `"lambda"`.
    """
    # Work on a copy of this module's namespace. Updating `globals()` in
    # place would permanently overwrite this module's own names with those
    # of the source module and of `custom_objects`.
    globs = dict(globals())
    module = config.pop(module_attr_name, None)
    if module in sys.modules:
        globs.update(sys.modules[module].__dict__)
    elif module is not None:
        # Note: we don't know the name of the function if it's a lambda.
        warnings.warn(
            f"{module} is not loaded, but a layer uses it. "
            "It may cause errors.",
            UserWarning,
            stacklevel=2,
        )
    if custom_objects:
        globs.update(custom_objects)

    function_type = config.pop(func_type_attr_name)
    if function_type == "function":
        # Simple lookup in custom objects
        return serialization_lib.deserialize_keras_object(
            config[func_attr_name],
            custom_objects=custom_objects,
            printable_module_name="function in wrapper",
        )
    if function_type == "lambda":
        if serialization_lib.in_safe_mode():
            raise ValueError(
                "Requested the deserialization of a layer with a "
                "Python `lambda` inside it. "
                "This carries a potential risk of arbitrary code execution "
                "and thus it is disallowed by default. If you trust the "
                "source of the saved model, you can pass `safe_mode=False` to "
                "the loading function in order to allow "
                "`lambda` loading."
            )
        # Unsafe deserialization from bytecode
        return generic_utils.func_load(config[func_attr_name], globs=globs)

    raise TypeError(
        f"Unknown function type received: {function_type}. "
        "Expected types are ['function', 'lambda']"
    )

685 

686 

def _default_dropout_state_filter_visitor(substate):
    """Return `False` for `tf.TensorArray` substates and `True` otherwise."""
    is_tensor_array = isinstance(substate, tf.TensorArray)
    return not is_tensor_array

689 

690 

def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
    """Like `map_structure_up_to`, but passes a running index to `map_fn`.

    `map_fn` is called as `map_fn(index, *leaf_args, **leaf_kwargs)`, where
    `index` starts at 0 and grows by one after every completed call.

    Args:
      shallow_structure: structure forwarded to
        `tf.__internal__.nest.map_structure_up_to`.
      map_fn: callable taking the index followed by the mapped arguments.
      *args: structures forwarded to `map_structure_up_to`.
      **kwargs: keyword arguments forwarded to `map_structure_up_to`.

    Returns:
      Whatever `tf.__internal__.nest.map_structure_up_to` returns.
    """
    calls_made = 0

    def enumerated_fn(*inner_args, **inner_kwargs):
        nonlocal calls_made
        result = map_fn(calls_made, *inner_args, **inner_kwargs)
        # Only advance once `map_fn` has returned successfully.
        calls_made += 1
        return result

    return tf.__internal__.nest.map_structure_up_to(
        shallow_structure, enumerated_fn, *args, **kwargs
    )

702