Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/legacy_tf_layers/base.py: 20%

223 statements  


# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=g-classes-have-attributes
"""Contains the base Layer class, from which all layers inherit."""
import copy
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.legacy_tf_layers import variable_scope_shim
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

# Avoid breaking users who directly import this symbol from this file.
# TODO(fchollet): remove this.
InputSpec = base_layer.InputSpec  # pylint: disable=invalid-name

_KERAS_STYLE_SCOPE = False

@keras_export(
    v1=['keras.__internal__.legacy.layers.experimental.keras_style_scope'])
@tf_contextlib.contextmanager
def keras_style_scope():
  """Use Keras-style variable management.

  All tf.layers and tf RNN cells created in this scope use Keras-style
  variable management. Creating such layers with a `scope=` argument is
  disallowed, as is passing `reuse=True`.

  The purpose of this scope is to allow users of existing layers to
  slowly transition to the Keras layers API without breaking existing
  functionality.

  One example of this is when using TensorFlow's RNN classes with Keras
  Models or Networks. Because Keras models do not properly set variable
  scopes, users of RNNs may either accidentally share scopes between two
  different models, or get errors about variables that already exist.

  Example:

  ```python
  class RNNModel(tf.keras.Model):

    def __init__(self, name):
      super(RNNModel, self).__init__(name=name)
      self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
          [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

    def call(self, input, state):
      return self.rnn(input, state)

  model_1 = RNNModel("model_1")
  model_2 = RNNModel("model_2")

  # OK
  output_1, next_state_1 = model_1(input, state)
  # Raises an error about trying to create an already existing variable.
  output_2, next_state_2 = model_2(input, state)
  ```

  The solution is to wrap the model construction and execution in a keras-style
  scope:

  ```python
  with keras_style_scope():
    model_1 = RNNModel("model_1")
    model_2 = RNNModel("model_2")

    # model_1 and model_2 are guaranteed to create their own variables.
    output_1, next_state_1 = model_1(input, state)
    output_2, next_state_2 = model_2(input, state)

    assert len(model_1.weights) > 0
    assert len(model_2.weights) > 0
    assert model_1.weights != model_2.weights
  ```

  Yields:
    A keras layer style scope.
  """
  global _KERAS_STYLE_SCOPE
  # Save the current value so that nested scopes restore the outer setting.
  stack = _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True
  try:
    yield
  finally:
    _KERAS_STYLE_SCOPE = stack

@keras_export(
    v1=['keras.__internal__.legacy.layers.experimental.set_keras_style'])
def set_keras_style():
  """Use Keras-style variable management.

  All tf.layers and tf RNN cells created after Keras style has been enabled
  use Keras-style variable management. Creating such layers with a
  `scope=` argument is disallowed, as is passing `reuse=True`.

  The purpose of this function is to allow users of existing layers to
  slowly transition to the Keras layers API without breaking existing
  functionality.

  For more details, see the documentation for `keras_style_scope`.

  Note: once Keras style has been set, it is set globally for the entire
  program and cannot be unset.

  Example:

  ```python
  set_keras_style()

  model_1 = RNNModel(name="model_1")
  model_2 = RNNModel(name="model_2")

  # model_1 and model_2 are guaranteed to create their own variables.
  output_1, next_state_1 = model_1(input, state)
  output_2, next_state_2 = model_2(input, state)

  assert len(model_1.weights) > 0
  assert len(model_2.weights) > 0
  assert model_1.weights != model_2.weights
  ```
  """
  global _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True


def _is_in_keras_style_scope():
  global _KERAS_STYLE_SCOPE
  return _KERAS_STYLE_SCOPE
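# Example (sketch, using only names defined above): `keras_style_scope`
# restores the previous flag value on exit, while `set_keras_style` flips it
# for the rest of the program.
#
#   assert not _is_in_keras_style_scope()
#   with keras_style_scope():
#     assert _is_in_keras_style_scope()
#   assert not _is_in_keras_style_scope()
#   set_keras_style()
#   assert _is_in_keras_style_scope()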

@keras_export(v1=['keras.__internal__.legacy.layers.Layer'])
class Layer(base_layer.Layer):
  """Base layer class.

  It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
  instead.

  Args:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: Default dtype of the layer's weights (default of `None` means use
      the type of the first input).

  Read-only properties:
    name: The name of the layer (string).
    dtype: Default dtype of the layer's weights (default of `None` means use
      the type of the first input).
    trainable_variables: List of trainable variables.
    non_trainable_variables: List of non-trainable variables.
    variables: List of all variables of this layer, trainable and
      non-trainable.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).

  Mutable properties:
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
  """

  def __init__(self, trainable=True, name=None, dtype=None,
               **kwargs):
    # For backwards compatibility, legacy layers do not use `ResourceVariable`
    # by default.
    self._use_resource_variables = False
    scope = kwargs.pop('_scope', None)
    self._reuse = kwargs.pop('_reuse', None)

    # Avoid an incorrect lint error
    self._trainable_weights = []
    self.built = False

    if dtype is None:
      # Indicates to infer dtype from inputs. When the V2 dtype behavior is
      # enabled, Keras layers default their dtype to floatx instead, so we pass
      # an "_infer" policy to keep the old V1 behavior.
      dtype = policy.Policy('_infer')

    if 'autocast' not in kwargs:
      kwargs['autocast'] = False

    # Mark that legacy layers should not be instrumented as Keras usage
    self._disable_keras_instrumentation = True

    super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
                                **kwargs)

    if _is_in_keras_style_scope():
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      if self._reuse is not None:
        raise ValueError(
            'reuse argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(self._reuse))
      self._keras_style = True
    else:
      self._keras_style = False

    self._call_has_scope_arg = 'scope' in self._call_fn_args
    if scope:
      with vs.variable_scope(scope) as captured_scope:
        self._scope = captured_scope
    else:
      self._scope = None
    self._current_scope = None

  # We no longer track graph in tf.layers layers. This property is only kept to
  # maintain API backward compatibility.
  @property
  def graph(self):
    warnings.warn('`Layer.graph` is deprecated and '
                  'will be removed in a future version. '
                  'Please stop using this property because tf.layers layers no '
                  'longer track their graph.')
    if context.executing_eagerly():
      raise RuntimeError('Layer.graph not supported when executing eagerly.')
    return None

  def _init_set_name(self, name):
    # Determine layer name (non-unique).
    if isinstance(name, vs.VariableScope):
      base_name = name.name
      self._name, _ = self._make_unique_name()
    else:
      base_name = name
      self._name = name
    if not name:
      self._name, base_name = self._make_unique_name()
    self._base_name = base_name

  def _make_unique_name(self, name_uid_map=None, avoid_names=None,
                        namespace='', zero_based=False):
    base_name = base_layer.to_snake_case(self.__class__.__name__)
    name = backend.unique_object_name(
        base_name,
        name_uid_map=name_uid_map,
        avoid_names=avoid_names,
        namespace=namespace,
        zero_based=zero_based)
    return (name, base_name)

  @property
  def scope_name(self):
    if not self._scope:
      raise ValueError('No name available for layer scope because the layer "' +
                       self._name + '" has not been used yet. The scope name '
                       'is determined the first time the layer instance is '
                       'called. You must therefore call the layer before '
                       'querying `scope_name`.')
    return self._scope.name

  def add_loss(self, losses, inputs=None):
    previous_losses_length = len(self._losses)
    previous_callable_losses_length = len(self._callable_losses)
    super(Layer, self).add_loss(losses, inputs=inputs)
    if not context.executing_eagerly():
      # TODO(fchollet): deprecate collection below.
      new_losses = self._losses[previous_losses_length:]
      new_callable_losses = self._callable_losses[
          previous_callable_losses_length:]
      for regularizer in new_callable_losses:
        loss_tensor = regularizer()
        if loss_tensor is not None:
          new_losses.append(loss_tensor)
      _add_elements_to_collection(
          new_losses,
          ops.GraphKeys.REGULARIZATION_LOSSES)

  def _name_scope(self):  # pylint: disable=method-hidden
    """Determines op naming for the Layer."""
    if self._keras_style:
      return super(Layer, self)._name_scope()
    return self._current_scope.original_name_scope

  def _set_scope(self, scope=None):
    if self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        with vs.variable_scope(
            scope if scope is not None else self._base_name) as captured_scope:
          self._scope = captured_scope
      else:
        with vs.variable_scope(
            scope, default_name=self._base_name) as captured_scope:
          self._scope = captured_scope
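  # Example (sketch, illustrative name): a layer built as `Layer(name='fc')`
  # keeps `self._scope is None` until its first call; `_set_scope` then
  # captures a `tf.compat.v1.VariableScope` named 'fc' (or a uniquified
  # variant), after which `scope_name` becomes queryable.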

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 use_resource=None,
                 synchronization=vs.VariableSynchronization.AUTO,
                 aggregation=vs.VariableAggregation.NONE,
                 partitioner=None,
                 **kwargs):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    Args:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      constraint: constraint instance (callable).
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: (optional) partitioner instance (callable). If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`. In this case,
        an instance of `PartitionedVariable` is returned. Available
        partitioners include `tf.compat.v1.fixed_size_partitioner` and
        `tf.compat.v1.variable_axis_size_partitioner`. For more details, see
        the documentation of `tf.compat.v1.get_variable` and the "Variable
        Partitioners and Sharding" section of the API guide.
      **kwargs: Additional keyword arguments.

    Returns:
      The created variable. Usually either a `Variable` or `ResourceVariable`
      instance. If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
    """

    for kwarg in kwargs:
      if kwarg != 'experimental_autocast':
        raise TypeError('Unknown keyword argument: %s' % kwarg)
    if self._keras_style:
      return super(Layer, self).add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          trainable=trainable and self.trainable,
          constraint=constraint,
          use_resource=use_resource,
          synchronization=vs.VariableSynchronization.AUTO,
          aggregation=vs.VariableAggregation.NONE,
          partitioner=partitioner,
          **kwargs)

    if synchronization == vs.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    def _should_add_regularizer(variable, existing_variable_set):
      if base_layer_utils.is_split_variable(variable):
        for var in variable:
          if var in existing_variable_set:
            return False
        return True
      else:
        return variable not in existing_variable_set

    init_graph = None
    if not context.executing_eagerly():
      default_graph = ops.get_default_graph()
      if default_graph.building_function:
        with ops.init_scope():
          # Retrieve the variables from the graph into which variables
          # will be lifted; if initialization ops will be lifted into
          # the eager context, then there is nothing to retrieve, since variable
          # collections are not supported when eager execution is enabled.
          if not context.executing_eagerly():
            init_graph = ops.get_default_graph()
            existing_variables = set(tf_variables.global_variables())
      else:
        # Initialization ops will not be lifted out of the default graph.
        init_graph = default_graph
        existing_variables = set(tf_variables.global_variables())

    if dtype is None:
      dtype = self.dtype or dtypes.float32

    self._set_scope(None)
    reuse = self.built or self._reuse
    prev_len_trainable = len(self._trainable_weights)
    with vs.variable_scope(
        self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
      self._current_scope = scope
      with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
        use_resource = (use_resource or
                        self._use_resource_variables or
                        scope.use_resource)
        if initializer is None:
          initializer = scope.initializer
        variable = super(Layer, self).add_weight(
            name,
            shape,
            dtype=dtypes.as_dtype(dtype),
            initializer=initializer,
            trainable=trainable and self.trainable,
            constraint=constraint,
            partitioner=partitioner,
            use_resource=use_resource,
            synchronization=synchronization,
            aggregation=aggregation,
            getter=vs.get_variable,
            **kwargs)

        if regularizer:
          # In eager/TF2, `existing_variables` may never have been assigned;
          # the first operand short-circuits so it is only read in graph mode.
          if (ops.executing_eagerly_outside_functions()
              or _should_add_regularizer(variable, existing_variables)):
            self._handle_weight_regularization(name, variable, regularizer)
          var_store = vs._get_default_variable_store()  # pylint: disable=protected-access
          # When the shim to get variable scope working in TF2 is used,
          # we need to explicitly make the shim track the regularization
          # losses as the collections will not be accessible.
          if hasattr(var_store, 'add_regularizer'):
            var_store.add_regularizer(variable, regularizer)

        if init_graph is not None:
          # Handle edge case where a custom getter has overridden `trainable`.
          # There is one known occurrence of this, in unit test
          # testBasicRNNCellNotTrainable in
          # contrib.rnn.python.kernel_tests.core_rnn_cell_test
          with init_graph.as_default():
            trainable_variables = tf_variables.trainable_variables()
            if (trainable and self.trainable and
                variable not in trainable_variables):
              # A custom getter / variable scope overrode the trainable flag.
              extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
              self._trainable_weights = self._trainable_weights[
                  :prev_len_trainable]
              self._non_trainable_weights += extra_trainable_vars
    return variable
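  # Example (sketch, hypothetical names and shapes): inside `build`, a legacy
  # layer typically calls
  #
  #   self.kernel = self.add_weight(
  #       'kernel',
  #       shape=[fan_in, units],
  #       initializer=tf.compat.v1.glorot_uniform_initializer(),
  #       regularizer=tf.compat.v1.keras.regularizers.l2(1e-4),
  #       trainable=True)
  #
  # which routes variable creation through `tf.compat.v1.get_variable`, so
  # `reuse` and custom getters on the surrounding variable scope apply.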

  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
    """

    scope = kwargs.pop('scope', None)

    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope  # pylint: disable=access-member-before-definition
      except AttributeError:
        scope_context_manager = None

      if scope_context_manager is None:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        scope_context_manager = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)

        # Do not cache variable scopes if Eager mode is enabled. If Eager mode
        # is enabled then we don't want to reuse scopes because the cached scope
        # might be from a FuncGraph or Eager scope we are no longer in.
        if not ops.executing_eagerly_outside_functions():
          self._always_reuse_variable_scope = scope_context_manager
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = variable_scope_shim.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

      if not context.executing_eagerly():
        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
      return outputs
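  # Example (sketch, hypothetical layer and inputs): the first call captures a
  # variable scope and creates variables; once `built` is True, later calls
  # re-enter the captured scope with reuse=True, so variables are shared:
  #
  #   y1 = layer(x1)  # creates variables under the layer's scope
  #   y2 = layer(x2)  # reuses the same variables via the built=True path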

  def __deepcopy__(self, memo):
    no_copy = set(['_graph', '_thread_local', '_metrics_lock'])
    shallow_copy = set(['_scope', '_always_reuse_variable_scope'])
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      if k in no_copy:
        setattr(result, k, v)
      elif k in shallow_copy:
        setattr(result, k, copy.copy(v))
      elif base_layer.is_tensor_or_tensor_list(v):
        setattr(result, k, v)
      else:
        setattr(result, k, copy.deepcopy(v, memo))
    return result

  def __setattr__(self, name, value):
    # Bypass the automatic dependency tracking performed by the parent Layer.
    super(trackable.Trackable, self).__setattr__(name, value)  # pylint: disable=bad-super-call

  @property
  def _is_legacy_layer(self):
    """Used by keras to check compatibility. This should not be overridden."""
    return True


def _add_elements_to_collection(elements, collection_list):
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    collection_set = {id(e) for e in collection}
    for element in elements:
      if id(element) not in collection_set:
        collection.append(element)
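# Example (sketch, assumes graph mode and hypothetical tensors): adding the
# same element twice leaves a single entry, because membership is tracked by
# `id()`:
#
#   with ops.Graph().as_default():
#     update = tf.compat.v1.assign_add(tf.Variable(0), 1)
#     _add_elements_to_collection(update, ops.GraphKeys.UPDATE_OPS)
#     _add_elements_to_collection(update, ops.GraphKeys.UPDATE_OPS)
#     assert len(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) == 1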