Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_preprocessing_layer.py: 35%

202 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the base ProcessingLayer and a subclass that uses Combiners.""" 

16 

17import abc 

18import collections 

19 

20import numpy as np 

21 

22from tensorflow.python.eager import context 

23from tensorflow.python.eager import def_function 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import sparse_tensor 

27from tensorflow.python.keras import backend 

28from tensorflow.python.keras.engine import data_adapter 

29from tensorflow.python.keras.engine.base_layer import Layer 

30from tensorflow.python.keras.utils import tf_utils 

31from tensorflow.python.keras.utils import version_utils 

32from tensorflow.python.ops import math_ops 

33from tensorflow.python.ops import sparse_ops 

34from tensorflow.python.ops import variables 

35from tensorflow.python.ops.ragged import ragged_tensor 

36from tensorflow.python.trackable import base as trackable 

37from tensorflow.python.util.tf_export import keras_export 

38 

39 

40@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer') 

41class PreprocessingLayer(Layer, metaclass=abc.ABCMeta): 

42 """Base class for Preprocessing Layers. 

43 

44 **Don't use this class directly: it's an abstract base class!** You may 

45 be looking for one of the many built-in 

46 [preprocessing layers](https://keras.io/guides/preprocessing_layers/) 

47 instead. 

48 

49 Preprocessing layers are layers whose state gets computed before model 

50 training starts. They do not get updated during training. 

51 Most preprocessing layers implement an `adapt()` method for state computation. 

52 

53 The `PreprocessingLayer` class is the base class you would subclass to 

54 implement your own preprocessing layers. 

55 

56 Attributes: 

57 streaming: Whether a layer can be adapted multiple times without resetting 

58 the state of the layer. 

59 """ 

60 _must_restore_from_config = True 

61 

62 def __init__(self, streaming=True, **kwargs): 

63 super(PreprocessingLayer, self).__init__(**kwargs) 

64 self._streaming = streaming 

65 self._is_compiled = False 

66 self._is_adapted = False 

67 

68 # Sets `is_adapted=False` when `reset_state` is called. 

69 self._reset_state_impl = self.reset_state 

70 self.reset_state = self._reset_state_wrapper 

71 

72 self._adapt_function = None 

73 

74 @property 

75 def streaming(self): 

76 """Whether `adapt` can be called twice without resetting the state.""" 

77 return self._streaming 

78 

79 @property 

80 def is_adapted(self): 

81 """Whether the layer has been fit to data already.""" 

82 return self._is_adapted 

83 

84 def update_state(self, data): 

85 """Accumulates statistics for the preprocessing layer. 

86 

87 Arguments: 

88 data: A mini-batch of inputs to the layer. 

89 """ 

90 raise NotImplementedError 

91 

92 def reset_state(self): # pylint: disable=method-hidden 

93 """Resets the statistics of the preprocessing layer.""" 

94 raise NotImplementedError 

95 

96 def merge_state(self, layers): 

97 """Merge the statistics of multiple preprocessing layers. 

98 

99 This layer will contain the merged state. 

100 

101 Arguments: 

102 layers: Layers whose statistics should be merged with the statistics of

103 this layer. 

104 """ 

105 raise NotImplementedError 

106 

107 def finalize_state(self): 

108 """Finalize the statistics for the preprocessing layer. 

109 

110 This method is called at the end of `adapt` or after restoring a serialized 

111 preprocessing layer's state. This method handles any one-time operations 

112 that should occur on the layer's state before `Layer.__call__`. 

113 """ 

114 pass 

115 

116 def make_adapt_function(self): 

117 """Creates a function to execute one step of `adapt`. 

118 

119 This method can be overridden to support custom adapt logic. 

120 This method is called by `PreprocessingLayer.adapt`. 

121 

122 Typically, this method directly controls `tf.function` settings, 

123 and delegates the actual state update logic to 

124 `PreprocessingLayer.update_state`. 

125 

126 This function is cached the first time `PreprocessingLayer.adapt` 

127 is called. The cache is cleared whenever `PreprocessingLayer.compile` 

128 is called. 

129 

130 Returns: 

131 Function. The function created by this method should accept a 

132 `tf.data.Iterator`, retrieve a batch, and update the state of the 

133 layer. 

134 """ 

135 if self._adapt_function is not None: 

136 return self._adapt_function 

137 

138 def adapt_step(iterator): 

139 data = next(iterator) 

140 self._adapt_maybe_build(data) 

141 self.update_state(data) 

142 

143 if self._steps_per_execution.numpy().item() == 1: 

144 adapt_fn = adapt_step 

145 else: 

146 

147 def adapt_fn(iterator): 

148 for _ in math_ops.range(self._steps_per_execution): 

149 adapt_step(iterator) 

150 

151 if not self._run_eagerly: 

152 adapt_fn = def_function.function(adapt_fn) 

153 

154 self._adapt_function = adapt_fn 

155 return self._adapt_function 

156 
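  # A possible override, sketched in comments: a subclass could swap in its own
  # accumulation step while keeping the caching and `tf.function` wrapping used
  # by the default implementation above. `_custom_accumulate` is a hypothetical
  # helper, not part of this class.
  #
  #   def make_adapt_function(self):
  #     if self._adapt_function is None:
  #       def adapt_fn(iterator):
  #         data = next(iterator)
  #         self._adapt_maybe_build(data)
  #         self._custom_accumulate(data)  # hypothetical accumulation step
  #       if not self._run_eagerly:
  #         adapt_fn = def_function.function(adapt_fn)
  #       self._adapt_function = adapt_fn
  #     return self._adapt_function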

157 def compile(self, run_eagerly=None, steps_per_execution=None): 

158 """Configures the layer for `adapt`. 

159 

160 Arguments: 

161 run_eagerly: Bool. Defaults to `False`. If `True`, this layer's logic

162 will not be wrapped in a `tf.function`. Recommended to leave this as 

163 `None` unless your layer cannot be run inside a `tf.function`.

164 steps_per_execution: Int. Defaults to 1. The number of batches to run 

165 during each `tf.function` call. Running multiple batches inside a 

166 single `tf.function` call can greatly improve performance on TPUs or 

167 small models with a large Python overhead. 

168 """ 

169 if steps_per_execution is None: 

170 steps_per_execution = 1 

171 self._configure_steps_per_execution(steps_per_execution) 

172 

173 if run_eagerly is None: 

174 run_eagerly = self.dynamic 

175 self._run_eagerly = run_eagerly 

176 

177 self._is_compiled = True 

178 

179 def adapt(self, data, batch_size=None, steps=None, reset_state=True): 

180 """Fits the state of the preprocessing layer to the data being passed. 

181 

182 After calling `adapt` on a layer, a preprocessing layer's state will not 

183 update during training. In order to make preprocessing layers efficient in 

184 any distribution context, they are kept constant with respect to any 

185 compiled `tf.Graph`s that call the layer. This does not affect layer usage

186 when adapting each layer only once, but if you adapt a layer multiple times 

187 you will need to take care to re-compile any compiled functions as follows: 

188 

189 * If you are adding a preprocessing layer to a `keras.Model`, you need to 

190 call `model.compile` after each subsequent call to `adapt`. 

191 * If you are calling a preprocessing layer inside `tf.data.Dataset.map`, 

192 you should call `map` again on the input `tf.data.Dataset` after each 

193 `adapt`. 

194 * If you are using a `tf.function` directly which calls a preprocessing 

195 layer, you need to call `tf.function` again on your callable after 

196 each subsequent call to `adapt`. 

197 

198 `tf.keras.Model` example with multiple adapts: 

199 

200 >>> layer = tf.keras.layers.experimental.preprocessing.Normalization( 

201 ... axis=None) 

202 >>> layer.adapt([0, 2]) 

203 >>> model = tf.keras.Sequential(layer) 

204 >>> model.predict([0, 1, 2]) 

205 array([-1., 0., 1.], dtype=float32) 

206 >>> layer.adapt([-1, 1]) 

207 >>> model.compile() # This is needed to re-compile model.predict! 

208 >>> model.predict([0, 1, 2]) 

209 array([0., 1., 2.], dtype=float32) 

210 

211 `tf.data.Dataset` example with multiple adapts: 

212 

213 >>> layer = tf.keras.layers.experimental.preprocessing.Normalization( 

214 ... axis=None) 

215 >>> layer.adapt([0, 2]) 

216 >>> input_ds = tf.data.Dataset.range(3) 

217 >>> normalized_ds = input_ds.map(layer) 

218 >>> list(normalized_ds.as_numpy_iterator()) 

219 [array([-1.], dtype=float32), 

220 array([0.], dtype=float32), 

221 array([1.], dtype=float32)] 

222 >>> layer.adapt([-1, 1]) 

223 >>> normalized_ds = input_ds.map(layer) # Re-map over the input dataset. 

224 >>> list(normalized_ds.as_numpy_iterator()) 

225 [array([0.], dtype=float32), 

226 array([1.], dtype=float32), 

227 array([2.], dtype=float32)] 

228 

229 Arguments: 

230 data: The data to train on. It can be passed either as a tf.data 

231 Dataset, or as a NumPy array.

232 batch_size: Integer or `None`. 

233 Number of samples per state update. 

234 If unspecified, `batch_size` will default to 32. 

235 Do not specify the `batch_size` if your data is in the 

236 form of datasets, generators, or `keras.utils.Sequence` instances 

237 (since they generate batches). 

238 steps: Integer or `None`. 

239 Total number of steps (batches of samples).

240 When training with input tensors such as 

241 TensorFlow data tensors, the default `None` is equal to 

242 the number of samples in your dataset divided by 

243 the batch size, or 1 if that cannot be determined. If `data` is a

244 `tf.data` dataset, and 'steps' is None, the epoch will run until 

245 the input dataset is exhausted. When passing an infinitely 

246 repeating dataset, you must specify the `steps` argument. This 

247 argument is not supported with array inputs. 

248 reset_state: Optional argument specifying whether to clear the state of 

249 the layer at the start of the call to `adapt`, or whether to start 

250 from the existing state. This argument may not be relevant to all 

251 preprocessing layers: a subclass of PreprocessingLayer may choose to 

252 raise an error if `reset_state` is set to `False`.

253 """ 

254 _disallow_inside_tf_function('adapt') 

255 if not version_utils.should_use_v2(): 

256 raise RuntimeError('`adapt` is only supported in tensorflow v2.') # pylint: disable=g-doc-exception 

257 if not self.streaming and self._is_adapted and not reset_state: 

258 raise ValueError('{} does not support calling `adapt` twice without '

259 'resetting the state.'.format(self.__class__.__name__)) 

260 if not self._is_compiled: 

261 self.compile() # Compile with defaults. 

262 if self.built and reset_state: 

263 self.reset_state() 

264 data_handler = data_adapter.DataHandler( 

265 data, 

266 batch_size=batch_size, 

267 steps_per_epoch=steps, 

268 epochs=1, 

269 steps_per_execution=self._steps_per_execution, 

270 distribute=False) 

271 self._adapt_function = self.make_adapt_function() 

272 for _, iterator in data_handler.enumerate_epochs(): 

273 with data_handler.catch_stop_iteration(): 

274 for _ in data_handler.steps(): 

275 self._adapt_function(iterator) 

276 if data_handler.should_sync: 

277 context.async_wait() 

278 self.finalize_state() 

279 self._is_adapted = True 

280 

281 def _reset_state_wrapper(self): 

282 """Calls `reset_state` and sets `adapted` to `False`.""" 

283 self._reset_state_impl() 

284 self._is_adapted = False 

285 

286 @trackable.no_automatic_dependency_tracking 

287 def _configure_steps_per_execution(self, steps_per_execution): 

288 self._steps_per_execution = variables.Variable( 

289 steps_per_execution, 

290 dtype='int64', 

291 aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA) 

292 

293 # TODO(omalleyt): Unify this logic with `Layer._maybe_build`. 

294 def _adapt_maybe_build(self, data): 

295 if not self.built: 

296 try: 

297 # If this is a Numpy array or tensor, we can get shape from .shape. 

298 # If not, an attribute error will be thrown. 

299 data_shape = data.shape 

300 data_shape_nones = tuple([None] * len(data.shape)) 

301 except AttributeError: 

302 # The input has an unknown number of dimensions. 

303 data_shape = None 

304 data_shape_nones = None 

305 

306 # TODO (b/159261555): move this to base layer build. 

307 batch_input_shape = getattr(self, '_batch_input_shape', None) 

308 if batch_input_shape is None: 

309 # Set the number of dimensions. 

310 self._batch_input_shape = data_shape_nones 

311 self.build(data_shape) 

312 self.built = True 

313 

314 
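# A minimal sketch of subclassing `PreprocessingLayer` directly: the layer
# tracks the running maximum of the data seen during `adapt` and uses it to
# scale inputs. The `AdaptableMax` name and its single-variable state are
# illustrative assumptions, not part of the Keras API.
class AdaptableMax(PreprocessingLayer):

  def build(self, input_shape):
    self.maximum = self.add_weight(
        name='maximum', shape=(), dtype='float32', trainable=False,
        initializer='zeros')

  def update_state(self, data):
    # Called once per batch by the adapt function created above.
    batch_max = math_ops.reduce_max(math_ops.cast(data, 'float32'))
    self.maximum.assign(math_ops.maximum(self.maximum, batch_max))

  def reset_state(self):  # pylint: disable=method-hidden
    self.maximum.assign(0.)

  def call(self, inputs):
    return math_ops.cast(inputs, 'float32') / self.maximum

# Expected usage (kept in comments so module import behavior is unchanged):
#   layer = AdaptableMax()
#   layer.adapt([1., 2., 4.])   # layer.maximum is now 4.
#   layer([2., 4.])             # -> [0.5, 1.0]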

315# TODO(omalleyt): This class will be gradually replaced. 

316class CombinerPreprocessingLayer(PreprocessingLayer): 

317 """Base class for PreprocessingLayers that do computation using a Combiner. 

318 

319 This class provides several helper methods to make creating a 

320 PreprocessingLayer easier. It assumes that the core of your computation will 

321 be done via a Combiner object. Subclassing this class to create a 

322 PreprocessingLayer allows your layer to be compatible with distributed 

323 computation. 

324 

325 This class is compatible with TensorFlow 2.0+.

326 """ 

327 

328 def __init__(self, combiner, **kwargs): 

329 super(CombinerPreprocessingLayer, self).__init__(**kwargs) 

330 self.state_variables = collections.OrderedDict() 

331 self._combiner = combiner 

332 self._adapt_accumulator = None 

333 

334 def reset_state(self): # pylint: disable=method-hidden 

335 self._adapt_accumulator = None 

336 

337 @trackable.no_automatic_dependency_tracking 

338 def update_state(self, data): 

339 if self._adapt_accumulator is None: 

340 self._adapt_accumulator = self._get_accumulator() 

341 self._adapt_accumulator = self._combiner.compute(data, 

342 self._adapt_accumulator) 

343 

344 def merge_state(self, layers): 

345 accumulators = ([self._get_accumulator()] + 

346 [l._get_accumulator() for l in layers]) # pylint: disable=protected-access 

347 merged_accumulator = self._combiner.merge(accumulators) 

348 self._set_accumulator(merged_accumulator) 

349 

350 def finalize_state(self): 

351 if self._adapt_accumulator is not None: 

352 self._set_accumulator(self._adapt_accumulator) 

353 

354 def compile(self, run_eagerly=None, steps_per_execution=None): 

355 # TODO(omalleyt): Remove this once sublayers are switched to new APIs. 

356 if run_eagerly is None: 

357 run_eagerly = True 

358 super(CombinerPreprocessingLayer, self).compile( 

359 run_eagerly=run_eagerly, steps_per_execution=steps_per_execution) 

360 

361 def adapt(self, data, batch_size=None, steps=None, reset_state=True): 

362 if not reset_state: 

363 self._adapt_accumulator = self._combiner.restore(self._restore_updates()) 

364 super(CombinerPreprocessingLayer, self).adapt( 

365 data, batch_size=batch_size, steps=steps, reset_state=reset_state) 

366 

367 def _add_state_variable(self, 

368 name, 

369 shape, 

370 dtype, 

371 initializer=None, 

372 partitioner=None, 

373 use_resource=None, 

374 **kwargs): 

375 """Add a variable that can hold state which is updated during adapt(). 

376 

377 Args: 

378 name: Variable name. 

379 shape: Variable shape. Defaults to scalar if unspecified. 

380 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 

381 initializer: initializer instance (callable). 

382 partitioner: Partitioner to be passed to the `Trackable` API. 

383 use_resource: Whether to use a `ResourceVariable`.

384 **kwargs: Additional keyword arguments. Accepted values are `getter` and 

385 `collections`. 

386 

387 Returns: 

388 The created variable. 

389 """ 

390 weight = self.add_weight( 

391 name=name, 

392 shape=shape, 

393 dtype=dtype, 

394 initializer=initializer, 

395 regularizer=None, 

396 trainable=False, 

397 constraint=None, 

398 partitioner=partitioner, 

399 use_resource=use_resource, 

400 **kwargs) 

401 # TODO(momernick): Do not allow collisions here. 

402 self.state_variables[name] = weight 

403 return weight 

404 
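  # Typical use from a subclass's `build`, sketched in comments; `num_bins` is
  # a hypothetical attribute, not part of this class:
  #
  #   def build(self, input_shape):
  #     self.counts = self._add_state_variable(
  #         name='counts', shape=(self.num_bins,), dtype=dtypes.int64,
  #         initializer='zeros')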

405 def _restore_updates(self): 

406 """Recreates a dict of updates from the layer's weights.""" 

407 data_dict = {} 

408 for name, var in self.state_variables.items(): 

409 data_dict[name] = var.numpy() 

410 return data_dict 

411 

412 def _get_accumulator(self): 

413 if self._is_adapted: 

414 return self._combiner.restore(self._restore_updates()) 

415 else: 

416 return None 

417 

418 def _set_accumulator(self, accumulator): 

419 updates = self._combiner.extract(accumulator) 

420 self._set_state_variables(updates) 

421 self._adapt_accumulator = None # Reset accumulator from adapt. 

422 

423 def _set_state_variables(self, updates): 

424 """Directly update the internal state of this Layer. 

425 

426 This method expects a string-keyed dict of {state_variable_name: state}. The 

427 precise nature of the state, and the names associated with it, are described by

428 the subclasses of CombinerPreprocessingLayer. 

429 

430 Args: 

431 updates: A string keyed dict of weights to update. 

432 

433 Raises: 

434 RuntimeError: if 'build()' was not called before '_set_state_variables()'.

435 """ 

436 # TODO(momernick): Do we need to do any more input sanitization? 

437 if not self.built: 

438 raise RuntimeError('_set_state_variables() must be called after build().') 

439 

440 with ops.init_scope(): 

441 for var_name, value in updates.items(): 

442 self.state_variables[var_name].assign(value) 

443 

444 
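# The distributed pattern the class above is designed for, sketched in
# comments: each worker adapts its own copy of a combiner-backed layer on one
# shard of the data, then a single layer absorbs the others' statistics via
# `merge_state`. The `make_layer` factory and the shard datasets below are
# hypothetical.
#
#   layers = [make_layer() for _ in range(num_shards)]
#   for layer, shard in zip(layers, shards):
#     layer.adapt(shard)
#   layers[0].merge_state(layers[1:])   # layers[0] now holds the merged state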

445def convert_to_list(values, sparse_default_value=None): 

446 """Convert a TensorLike, CompositeTensor, or ndarray into a Python list.""" 

447 if tf_utils.is_ragged(values): 

448 # There is a corner case when dealing with ragged tensors: if you get an 

449 # actual RaggedTensor (not a RaggedTensorValue) passed in non-eager mode, 

450 # you can't call to_list() on it without evaluating it first. However, 

451 # because we don't yet fully support composite tensors across Keras, 

452 # backend.get_value() won't evaluate the tensor. 

453 # TODO(momernick): Get Keras to recognize composite tensors as Tensors 

454 # and then replace this with a call to backend.get_value. 

455 if (isinstance(values, ragged_tensor.RaggedTensor) and 

456 not context.executing_eagerly()): 

457 values = backend.get_session(values).run(values) 

458 values = values.to_list() 

459 

460 if isinstance(values, 

461 (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): 

462 if sparse_default_value is None: 

463 if dtypes.as_dtype(values.values.dtype) == dtypes.string: 

464 sparse_default_value = '' 

465 else: 

466 sparse_default_value = -1 

467 dense_tensor = sparse_ops.sparse_tensor_to_dense( 

468 values, default_value=sparse_default_value) 

469 values = backend.get_value(dense_tensor) 

470 

471 if isinstance(values, ops.Tensor): 

472 values = backend.get_value(values) 

473 

474 # We may get passed a ndarray or the code above may give us a ndarray. 

475 # In either case, we want to force it into a standard python list. 

476 if isinstance(values, np.ndarray): 

477 values = values.tolist() 

478 

479 return values 

480 

481 
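# Illustrative behavior of `convert_to_list`, assuming eager execution (shown
# as comments; the literal values are examples only):
#
#   rt = ragged_tensor.RaggedTensor.from_row_lengths([1, 2, 3], [2, 1])
#   convert_to_list(rt)     # -> [[1, 2], [3]]
#
#   st = sparse_tensor.SparseTensor(
#       indices=[[0, 0], [1, 1]], values=[7, 8], dense_shape=[2, 2])
#   convert_to_list(st)     # -> [[7, -1], [-1, 8]]
#
# For string-valued sparse tensors the default fill value is '' instead of -1.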

482# TODO(omalleyt): This class will be gradually replaced. 

483class Combiner(object): 

484 """Functional object that defines a shardable computation. 

485 

486 This object defines functions required to create and manipulate data objects. 

487 These data objects, referred to below as 'accumulators', are computation- 

488 specific and may be implemented alongside concrete subclasses of Combiner 

489 (if necessary; some computations may be simple enough that standard Python

490 types can be used as accumulators). 

491 

492 The intent for this class is that by describing computations in this way, we 

493 can arbitrarily shard a dataset, perform computations on a subset, and then 

494 merge the computation into a final result. This enables distributed 

495 computation. 

496 

497 The combiner itself does not own any state - all computational state is owned 

498 by the accumulator objects. This is so that we can have an arbitrary number of 

499 Combiners (thus sharding the computation N ways) without risking any change 

500 to the underlying computation. These accumulator objects are uniquely 

501 associated with each Combiner; a Combiner defines what the accumulator object 

502 should be and will only work with accumulators of that type. 

503 """ 

504 __metaclass__ = abc.ABCMeta 

505 

506 def __repr__(self): 

507 return '<{}>'.format(self.__class__.__name__) 

508 

509 @abc.abstractmethod 

510 def compute(self, batch_values, accumulator=None): 

511 """Compute a step in this computation, returning a new accumulator. 

512 

513 This method computes a step of the computation described by this Combiner. 

514 If an accumulator is passed, the data in that accumulator is also used; so 

515 compute(batch_values) results in f(batch_values), while 

516 compute(batch_values, accumulator) results in 

517 merge(f(batch_values), accumulator). 

518 

519 Args: 

520 batch_values: A list of ndarrays representing the values of the inputs for 

521 this step of the computation. 

522 accumulator: the current accumulator. Can be None. 

523 

524 Returns: 

525 An accumulator that includes the passed batch of inputs. 

526 """ 

527 pass 

528 

529 @abc.abstractmethod 

530 def merge(self, accumulators): 

531 """Merge several accumulators to a single accumulator. 

532 

533 This method takes the partial values in several accumulators and combines 

534 them into a single accumulator. This computation must not be order-specific 

535 (that is, merge([a, b]) must return the same result as merge([b, a])).

536 

537 Args: 

538 accumulators: the accumulators to merge, as a list. 

539 

540 Returns: 

541 A merged accumulator. 

542 """ 

543 pass 

544 

545 @abc.abstractmethod 

546 def extract(self, accumulator): 

547 """Convert an accumulator into a dict of output values. 

548 

549 Args: 

550 accumulator: The accumulator to convert. 

551 

552 Returns: 

553 A dict of ndarrays representing the data in this accumulator. 

554 """ 

555 pass 

556 

557 @abc.abstractmethod 

558 def restore(self, output): 

559 """Create an accumulator based on 'output'. 

560 

561 This method creates a new accumulator with identical internal state to the 

562 one used to create the data in 'output'. This means that if you do 

563 

564 output_data = combiner.extract(accumulator_1) 

565 accumulator_2 = combiner.restore(output_data) 

566 

567 then accumulator_1 and accumulator_2 will have identical internal state, and 

568 computations using either of them will be equivalent. 

569 

570 Args: 

571 output: The data output from a previous computation. Should be in the same 

572 form as produced by 'extract'.

573 

574 Returns: 

575 A new accumulator. 

576 """ 

577 pass 

578 

579 @abc.abstractmethod 

580 def serialize(self, accumulator): 

581 """Serialize an accumulator for a remote call. 

582 

583 This function serializes an accumulator to be sent to a remote process. 

584 

585 Args: 

586 accumulator: The accumulator to serialize. 

587 

588 Returns: 

589 A byte string representing the passed accumulator. 

590 """ 

591 pass 

592 

593 @abc.abstractmethod 

594 def deserialize(self, encoded_accumulator): 

595 """Deserialize an accumulator received from 'serialize()'. 

596 

597 This function deserializes an accumulator serialized by 'serialize()'. 

598 

599 Args: 

600 encoded_accumulator: A byte string representing an accumulator. 

601 

602 Returns: 

603 The accumulator represented by the passed byte_string. 

604 """ 

605 pass 

606 

607 
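# A minimal concrete `Combiner`, provided only as a sketch: it tracks a running
# count and sum so that a layer could extract a mean. The (count, total) tuple
# accumulator and the `MeanCombiner` name are assumptions for illustration, not
# an API defined by this module.
class MeanCombiner(Combiner):

  def compute(self, batch_values, accumulator=None):
    count, total = accumulator if accumulator is not None else (0, 0.0)
    batch = np.asarray(batch_values, dtype='float64')
    return count + batch.size, total + float(batch.sum())

  def merge(self, accumulators):
    # Order-independent: summing counts and totals commutes.
    counts, totals = zip(*accumulators)
    return sum(counts), sum(totals)

  def extract(self, accumulator):
    count, total = accumulator
    return {'count': np.array(count), 'mean': np.array(total / max(count, 1))}

  def restore(self, output):
    count = int(output['count'])
    return count, float(output['mean']) * count

  def serialize(self, accumulator):
    return np.array(accumulator, dtype='float64').tobytes()

  def deserialize(self, encoded_accumulator):
    count, total = np.frombuffer(encoded_accumulator, dtype='float64')
    return int(count), float(total)

# A combiner-backed layer would typically call `compute` from `update_state`
# and push the dict returned by `extract` into its weights in `finalize_state`.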

608def _disallow_inside_tf_function(method_name): 

609 """Disallow calling a method inside a `tf.function`.""" 

610 if ops.inside_function(): 

611 error_msg = ( 

612 'Detected a call to `PreprocessingLayer.{method_name}` inside a ' 

613 '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '

614 'endpoint that manages its own `tf.function`. Please move the call ' 

615 'to `PreprocessingLayer.{method_name}` outside of all enclosing ' 

616 '`tf.function`s. Note that you can call a `PreprocessingLayer` ' 

617 'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, ' 

618 'or update its state like: `layer.update_state(x)`.').format( 

619 method_name=method_name) 

620 raise RuntimeError(error_msg)