Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_utils_v1.py: 17%

782 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Training-related utilities.""" 

16 

17import abc 

18import atexit 

19import collections 

20import functools 

21import multiprocessing.pool 

22import threading 

23import time 

24 

25import numpy as np 

26 

27from tensorflow.core.framework import graph_pb2 

28from tensorflow.python import tf2 

29from tensorflow.python.data.experimental.ops import cardinality 

30from tensorflow.python.data.ops import dataset_ops 

31from tensorflow.python.data.ops import iterator_ops 

32from tensorflow.python.data.ops import options as options_lib 

33from tensorflow.python.eager import context 

34from tensorflow.python.framework import composite_tensor 

35from tensorflow.python.framework import dtypes 

36from tensorflow.python.framework import errors 

37from tensorflow.python.framework import smart_cond 

38from tensorflow.python.framework import sparse_tensor 

39from tensorflow.python.framework import tensor_conversion 

40from tensorflow.python.framework import tensor_spec 

41from tensorflow.python.framework import tensor_util 

42from tensorflow.python.keras import backend 

43from tensorflow.python.keras import callbacks as cbks 

44from tensorflow.python.keras import losses 

45from tensorflow.python.keras import metrics as metrics_module 

46from tensorflow.python.keras.utils import data_utils 

47from tensorflow.python.keras.utils import generic_utils 

48from tensorflow.python.keras.utils import losses_utils 

49from tensorflow.python.keras.utils import tf_inspect 

50from tensorflow.python.ops import array_ops 

51from tensorflow.python.ops import gen_array_ops 

52from tensorflow.python.ops import math_ops 

53from tensorflow.python.ops import sparse_ops 

54from tensorflow.python.ops.ragged import ragged_tensor 

55from tensorflow.python.ops.ragged import ragged_tensor_value 

56from tensorflow.python.platform import tf_logging as logging 

57from tensorflow.python.types import data as data_types 

58from tensorflow.python.util import nest 

59 

60 

61def is_composite_or_composite_value(tensor): 

62 """Returns true if 'tensor' is a CompositeTensor or a CT Value object.""" 

63 # TODO(b/125094323): This should be isinstance(CompositeTensor) or 

64 # isinstance(CompositeTensorValue) once we support that. 

65 return isinstance( 

66 tensor, 

67 (composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue, 

68 ragged_tensor_value.RaggedTensorValue)) 

69 

70 
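
A minimal usage sketch (assuming the helper is in scope, e.g. imported from this module); it accepts both eager CompositeTensors and the graph-mode value objects:

import numpy as np
from tensorflow.python.framework import sparse_tensor

sp_value = sparse_tensor.SparseTensorValue(
    indices=np.array([[0, 0]]), values=np.array([1.0]), dense_shape=(2, 2))
assert is_composite_or_composite_value(sp_value)               # SparseTensorValue
assert not is_composite_or_composite_value(np.zeros((2, 2)))   # plain ndarray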

71class Aggregator(object, metaclass=abc.ABCMeta): 

72 """Abstract base class used to aggregate batch-level outputs of a loop. 

73 

74 Attributes: 

75 use_steps: Whether the loop is using `step` or `batch_size`. 

76 num_samples: Total number of samples: `batch_size * num_batches`. 

77 steps: Total number of steps. 

78 batch_size: Batch size. It is used for validation checks between inputs and 

79 outputs. 

80 results: What to return at the end of the aggregation loop. 

81 """ 

82 

83 def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None): 

84 self.use_steps = use_steps 

85 self.num_samples = num_samples 

86 self.steps = steps 

87 self.batch_size = batch_size 

88 self.results = [] 

89 

90 @abc.abstractmethod 

91 def create(self, batch_outs): 

92 """Creates the initial results from the first batch outputs. 

93 

94 Args: 

95 batch_outs: A list of batch-level outputs. 

96 """ 

97 raise NotImplementedError('Must be implemented in subclasses.') 

98 

99 @abc.abstractmethod 

100 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 

101 """Aggregates batch-level results into total results. 

102 

103 Args: 

104 batch_outs: A list of batch-level outputs. 

105 batch_start: The start index of this batch. Always `None` if `use_steps` 

106 is `True`. 

107 batch_end: The end index of this batch. Always `None` if `use_steps` is 

108 `True`. 

109 """ 

110 raise NotImplementedError('Must be implemented in subclasses.') 

111 

112 @abc.abstractmethod 

113 def finalize(self): 

114 """Prepares the total results to be returned.""" 

115 raise NotImplementedError('Must be implemented in subclasses.') 

116 

117 

118class MetricsAggregator(Aggregator): 

119 """Aggregator that calculates loss and metrics info. 

120 

121 Attributes: 

122 use_steps: Whether the loop is using `step` or `batch_size`. 

123 num_samples: Total number of samples: `batch_size*num_batches`. 

124 steps: Total number of steps, i.e. number of times to iterate over a dataset 

125 to cover all samples. 

126 """ 

127 

128 def __init__(self, use_steps, num_samples=None, steps=None): 

129 super(MetricsAggregator, self).__init__( 

130 use_steps=use_steps, 

131 num_samples=num_samples, 

132 steps=steps, 

133 batch_size=None) 

134 

135 def create(self, batch_outs): 

136 self.results = [0.] * len(batch_outs) 

137 

138 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 

139 # Loss. 

140 if self.use_steps: 

141 self.results[0] += batch_outs[0] 

142 else: 

143 self.results[0] += batch_outs[0] * (batch_end - batch_start) 

144 # Metrics (always stateful, just grab current values.) 

145 self.results[1:] = batch_outs[1:] 

146 

147 def finalize(self): 

148 if not self.results: 

149 raise ValueError('Empty training data.') 

150 self.results[0] /= (self.num_samples or self.steps) 

151 

152 
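
A hypothetical walk-through of the averaging above (toy numbers): two batches of 3 and 2 samples with per-batch losses 0.6 and 0.1 give a sample-weighted mean of (0.6*3 + 0.1*2) / 5 = 0.4, while the metric slot just keeps the last stateful value:

agg = MetricsAggregator(use_steps=False, num_samples=5)
agg.create([0.6, 0.9])                  # [loss, accuracy] from the first batch
agg.aggregate([0.6, 0.9], batch_start=0, batch_end=3)
agg.aggregate([0.1, 0.8], batch_start=3, batch_end=5)
agg.finalize()
print(agg.results)                      # [0.4, 0.8]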

153def _append_sparse_tensor_value(target, to_append): 

154 """Append sparse tensor value objects.""" 

155 # Make sure the sparse tensors are of the same size (except for the 0th dim). 

156 if len(target.dense_shape) != len(to_append.dense_shape): 

157 raise RuntimeError( 

158 'Unable to concatenate %s and %s. The inner dense shapes do not ' 

159 'have the same number of dimensions (%s vs %s)' % 

160 (target, to_append, target.dense_shape, to_append.dense_shape)) 

161 

162 if target.dense_shape[1:] != to_append.dense_shape[1:]: 

163 raise RuntimeError( 

164 'Unable to concatenate %s and %s. The inner dense shapes do not ' 

165 'match inner dimensions (%s vs %s)' % 

166 (target, to_append, target.dense_shape[1:], to_append.dense_shape[1:])) 

167 

168 # Add the to_append indices to target, updating the 0th value, and keeping 

169 # track of the maximum so we know the final dense_shape of this tensor. 

170 base_dim0_value = target.dense_shape[0] 

171 max_dim0_value = target.dense_shape[0] 

172 new_indices = target.indices 

173 for index in to_append.indices: 

174 # Here, we iterate through the sparse indices of the tensor to append. For 

175 # each index, we update its zeroth value (the batch index) by adding the 

176 # number of batch items in the tensor we are appending to (so an index 

177 # of [0, 0, 1] for a value that is being appended to a tensor with 0th dim 

178 # size 3 would become [3, 0, 1].) 

179 index[0] += base_dim0_value 

180 max_dim0_value = max(max_dim0_value, index[0]) 

181 new_indices = np.append(new_indices, [index], axis=0) 

182 

183 # Extend the values array to contain all of the appended values. These will 

184 # be in the same order as the indices added above. 

185 new_values = np.concatenate((target.values, to_append.values), axis=0) 

186 

187 # Create a new dense shape by replacing the value for the 0th dimension 

188 # with the new max dim0 value. 

189 new_dense_shape = list(target.dense_shape) 

190 new_dense_shape[0] = max_dim0_value + 1 

191 new_dense_shape = tuple(new_dense_shape) 

192 

193 return sparse_tensor.SparseTensorValue( 

194 indices=new_indices, values=new_values, dense_shape=new_dense_shape) 

195 

196 
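
An illustrative sketch of the index shifting described above (values invented for illustration): the appended tensor's batch indices are offset by the target's 0th dimension:

import numpy as np
from tensorflow.python.framework import sparse_tensor

a = sparse_tensor.SparseTensorValue(
    indices=np.array([[0, 1]]), values=np.array([1.0]), dense_shape=(2, 3))
b = sparse_tensor.SparseTensorValue(
    indices=np.array([[1, 2]]), values=np.array([2.0]), dense_shape=(2, 3))
merged = _append_sparse_tensor_value(a, b)
print(merged.indices)      # [[0 1] [3 2]] - b's row index 1 shifted by 2
print(merged.dense_shape)  # (4, 3)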

197def _append_ragged_tensor_value(target, to_append): 

198 """Append ragged tensor value objects.""" 

199 # Make sure the ragged tensors are of the same size (save for the 0th dim). 

200 if len(target.shape) != len(to_append.shape): 

201 raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append)) 

202 

203 if target.shape[1:] != to_append.shape[1:]: 

204 raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append)) 

205 

206 adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1] 

207 new_row_splits = np.append(target.row_splits, adjusted_row_splits) 

208 if isinstance(target.values, ragged_tensor_value.RaggedTensorValue): 

209 new_values = _append_ragged_tensor_value(target.values, to_append.values) 

210 else: 

211 new_values = np.concatenate((target.values, to_append.values), axis=0) 

212 

213 return ragged_tensor_value.RaggedTensorValue(new_values, new_row_splits) 

214 

215 
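
A small sketch of the row_splits arithmetic (toy values, assuming the function is in scope): the appended splits are offset by the target's last split so row boundaries stay consistent:

import numpy as np
from tensorflow.python.ops.ragged import ragged_tensor_value

a = ragged_tensor_value.RaggedTensorValue(
    values=np.array([1, 2, 3]), row_splits=np.array([0, 1, 3], dtype=np.int64))
b = ragged_tensor_value.RaggedTensorValue(
    values=np.array([4, 5]), row_splits=np.array([0, 2], dtype=np.int64))
merged = _append_ragged_tensor_value(a, b)
print(merged.row_splits)   # [0 1 3 5] - b's split 2 became 2 + 3
print(merged.values)       # [1 2 3 4 5]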

216def _append_composite_tensor(target, to_append): 

217 """Helper function to append composite tensors to each other in the 0 axis. 

218 

219 In order to support batching within a fit/evaluate/predict call, we need 

220 to be able to aggregate within a CompositeTensor. Unfortunately, the CT 

221 API currently does not make this easy - especially in V1 mode, where we're 

222 working with CompositeTensor Value objects that have no connection with the 

223 CompositeTensors that created them. 

224 

225 Args: 

226 target: CompositeTensor or CompositeTensor value object that will be 

227 appended to. 

228 to_append: CompositeTensor or CompositeTensor value object to append to 

229 'target'. 

230 

231 Returns: 

232 A CompositeTensor or CompositeTensor value object. 

233 

234 Raises: 

235 RuntimeError: if concatenation is not possible. 

236 """ 

237 if type(target) is not type(to_append): 

238 raise RuntimeError('Unable to concatenate %s and %s' % 

239 (type(target), type(to_append))) 

240 

241 # Perform type-specific concatenation. 

242 # TODO(b/125094323): This should be replaced by a simple call to 

243 # target.append() that should work on all of the below classes. 

244 

245 # If we're seeing a CompositeTensor here, we know it's because we're in 

246 # Eager mode (or else we'd have evaluated the CT to a CT Value object 

247 # already). Therefore, it's safe to call concat() on it without evaluating 

248 # the result any further. If not - that is, if we're seeing a 

249 # SparseTensorValue or a RaggedTensorValue - we need to hand-update it 

250 # since we're outside of the graph anyways. 

251 if isinstance(target, sparse_tensor.SparseTensor): 

252 # We need to invoke the sparse version of concatenate here - tf.concat 

253 # won't work. 

254 return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0) 

255 elif isinstance(target, ragged_tensor.RaggedTensor): 

256 return array_ops.concat([target, to_append], axis=0) 

257 elif isinstance(target, sparse_tensor.SparseTensorValue): 

258 return _append_sparse_tensor_value(target, to_append) 

259 elif isinstance(target, ragged_tensor_value.RaggedTensorValue): 

260 return _append_ragged_tensor_value(target, to_append) 

261 else: 

262 raise RuntimeError('Attempted to concatenate unsupported object %s.' % 

263 type(target)) 

264 

265 

266class ConcatAggregator(Aggregator): 

267 """Combine tensor-likes which cannot be merged on the fly. 

268 

269 This class expects to aggregate a single tensor-like rather than a nested 

270 structure of tensor-likes. 

271 """ 

272 

273 def __init__(self, batch_size): 

274 self.composite = None 

275 super(ConcatAggregator, self).__init__( 

276 use_steps=True, num_samples=None, steps=None, batch_size=batch_size) 

277 

278 def create(self, batch_element): 

279 self.composite = is_composite_or_composite_value(batch_element) 

280 

281 def aggregate(self, batch_element, batch_start=None, batch_end=None): 

282 

283 # TODO(psv): Add num_samples check here to detect when output batch 

284 # #samples is < batch size and != input batch #samples. 

285 if self.batch_size and self.batch_size < batch_element.shape[0]: 

286 raise ValueError( 

287 'Mismatch between expected batch size and model output batch size. ' 

288 'Output shape = {}, expected output shape = {}'.format( 

289 batch_element.shape, 

290 (self.batch_size,) + batch_element.shape[1:])) 

291 self.results.append(batch_element) 

292 

293 def finalize(self): 

294 # Special case of single batch inference which skips a copy. 

295 if len(self.results) == 1: 

296 self.results = self.results[0] 

297 

298 elif self.composite: 

299 # TODO(taylorrobie): efficiently concatenate. 

300 results = self.results[0] 

301 for r in self.results[1:]: 

302 results = _append_composite_tensor(results, r) 

303 self.results = results 

304 

305 else: 

306 self.results = np.concatenate(self.results, axis=0) 

307 

308 

309_COPY_THREADS = 4 

310_COPY_POOL = None 

311 

312 

313def get_copy_pool(): 

314 """Shared threadpool for copying arrays. 

315 

316 Pool instantiation takes ~ 2ms, so a singleton pool is used rather than 

317 creating a pool per SliceAggregator. 

318 

319 Returns: 

320 The global copy threadpool. 

321 """ 

322 global _COPY_POOL 

323 if _COPY_POOL is None: 

324 _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS) 

325 atexit.register(_COPY_POOL.close) 

326 return _COPY_POOL 

327 

328 

329class SliceAggregator(Aggregator): 

330 """Combine arrays where the final size is known. 

331 

332 This class expects to aggregate a single tensor-like rather than a nested 

333 structure of tensor-likes. 

334 

335 NumPy copies are an operation that threads handle quite well because all of 

336 the heavy lifting is in C and does not need the GIL. Moreover, we can perform 

337 lock-free writes to the same buffer in multiple threads because the nature of 

338 result aggregation guarantees that either the indices are disjoint or the 

339 aggregator will throw an exception in finalize. Moreover, because aggregation 

340 is performed on the slowest varying dimension, assignments for a given batch 

341 will write to contiguous blocks of memory, further minimizing contention. 

342 

343 There is, however, some scheduling and context switching overhead which will 

344 offset the gains from pipelining the slice assignment. Below a given threshold 

345 it is faster to simply assign in the main thread rather than enqueue the 

346 assignment in a side thread. The exact threshold will vary from system to 

347 system, but the time is not very sensitive to the exact transition so a value 

348 of 2 ** 14 was chosen which should be reasonable on most systems. 

349 """ 

350 

351 _BINARY_SIZE_THRESHOLD = 2 ** 14 

352 _MAX_COPY_SECONDS = 300 

353 

354 def __init__(self, num_samples, batch_size): 

355 self._async_copies = [] 

356 self._pool = get_copy_pool() 

357 self._errors = [] 

358 super(SliceAggregator, self).__init__( 

359 use_steps=False, 

360 num_samples=num_samples, 

361 steps=None, 

362 batch_size=batch_size) 

363 

364 def create(self, batch_element): 

365 # This step does not need to be pipelined because NumPy empty array 

366 # initialization is effectively instantaneous. 

367 shape = (self.num_samples,) + batch_element.shape[1:] 

368 dtype = batch_element.dtype 

369 

370 self.results = np.empty(shape=shape, dtype=dtype) 

371 

372 def aggregate(self, batch_element, batch_start, batch_end): 

373 # Fail early. 

374 if self._errors: 

375 raise self._errors[0] 

376 

377 # In the special case of single batch inference, no copy is needed. 

378 if batch_end - batch_start == self.num_samples: 

379 if self.num_samples != batch_element.shape[0]: 

380 raise ValueError( 

381 'Mismatch between expected batch size and model output batch size. ' 

382 'Output shape = {}, expected output shape = {}'.format( 

383 batch_element.shape, self.results.shape)) 

384 

385 self.results = batch_element 

386 return 

387 

388 # This is an approximate threshold, so we don't need to consider the number 

389 # of bytes per element. 

390 num_elements = np.prod(batch_element.shape) 

391 if num_elements < self._BINARY_SIZE_THRESHOLD: 

392 self.results[batch_start:batch_end] = batch_element 

393 else: 

394 is_finished = threading.Event() 

395 self._pool.apply_async( 

396 self._slice_assign, 

397 args=(batch_element, batch_start, batch_end, is_finished)) 

398 self._async_copies.append(is_finished) 

399 

400 def _slice_assign(self, batch_element, batch_start, batch_end, is_finished): 

401 """Legacy utility method to slice input arrays.""" 

402 try: 

403 self.results[batch_start:batch_end] = batch_element 

404 

405 except Exception as e: # pylint: disable=broad-except 

406 # `_slice_assign` should only be called in threads and exceptions raised 

407 # in threads do not carry over to the main thread. So instead we perform 

408 # a broad catch in the thread and then store the exception to be re-raised 

409 # in the main thread. 

410 self._errors.append(e) 

411 

412 finally: 

413 is_finished.set() 

414 

415 def finalize(self): 

416 start_time = time.time() 

417 for is_finished in self._async_copies: 

418 timeout = max([0., self._MAX_COPY_SECONDS - (time.time() - start_time)]) 

419 if not is_finished.wait(timeout): 

420 raise ValueError('Timed out waiting for copy to complete.') 

421 

422 if self._errors: 

423 raise self._errors[0] 

424 

425 
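
A rough usage sketch under the docstring's assumption that the final buffer size is known up front; batches this small are copied inline rather than on the thread pool:

import numpy as np

agg = SliceAggregator(num_samples=4, batch_size=2)
first = np.arange(6, dtype=np.float32).reshape(2, 3)
second = first + 10.
agg.create(first)              # pre-allocates an empty (4, 3) buffer
agg.aggregate(first, 0, 2)     # writes rows 0-1
agg.aggregate(second, 2, 4)    # writes rows 2-3
agg.finalize()                 # waits for any outstanding async copies
print(agg.results.shape)       # (4, 3)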

426class OutputsAggregator(Aggregator): 

427 """Aggregator that concatenates outputs.""" 

428 

429 _structure = None 

430 

431 def create(self, batch_outs): 

432 # SparseTensorValue is a named tuple which nest will flatten, so we need 

433 # to guard it to properly handle the structure. 

434 self._structure = nest.get_traverse_shallow_structure( 

435 lambda x: not is_composite_or_composite_value(x), batch_outs) 

436 batch_outs = nest.flatten_up_to(self._structure, batch_outs) 

437 

438 for batch_element in batch_outs: 

439 if is_composite_or_composite_value(batch_element): 

440 # If the output is not a ndarray, it will be either a composite tensor 

441 # or a composite tensor's Value object. In either case, we can't 

442 # allocate an array to hold the object - we'll handle it later. 

443 self.results.append(ConcatAggregator(self.batch_size)) 

444 elif isinstance(batch_element, np.ndarray): 

445 self.results.append( 

446 (ConcatAggregator(self.batch_size) if self.use_steps else 

447 SliceAggregator(self.num_samples, self.batch_size))) 

448 else: 

449 # This is not a ndarray, a CompositeTensor, or a CompositeTensorValue. 

450 # Fail fast rather than trying to concatenate it. 

451 raise RuntimeError('Attempted to aggregate unsupported object {}.' 

452 .format(batch_element)) 

453 

454 self.results[-1].create(batch_element) 

455 

456 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 

457 batch_outs = nest.flatten_up_to(self._structure, batch_outs) 

458 for batch_element, result in zip(batch_outs, self.results): 

459 result.aggregate(batch_element, batch_start, batch_end) 

460 

461 def finalize(self): 

462 for result in self.results: 

463 result.finalize() 

464 self.results = [i.results for i in self.results] 

465 self.results = nest.pack_sequence_as(self._structure, self.results) 

466 

467 
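
A sketch for a hypothetical model with two ndarray outputs, aggregated by sample index (use_steps=False), so each output gets its own SliceAggregator:

import numpy as np

agg = OutputsAggregator(use_steps=False, num_samples=4, batch_size=2)
batch = [np.zeros((2, 1)), np.ones((2, 3))]
agg.create(batch)
agg.aggregate(batch, 0, 2)
agg.aggregate([np.ones((2, 1)), np.zeros((2, 3))], 2, 4)
agg.finalize()
print([r.shape for r in agg.results])   # [(4, 1), (4, 3)]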

468def get_progbar(model, count_mode, include_metrics=True): 

469 """Get Progbar.""" 

470 if include_metrics: 

471 stateful_metric_names = getattr(model, 'metrics_names', None) 

472 if stateful_metric_names: 

473 stateful_metric_names = stateful_metric_names[1:] # Exclude `loss` 

474 else: 

475 stateful_metric_names = None 

476 return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names) 

477 

478 

479def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'): 

480 """Determine the number of samples provided for training and evaluation. 

481 

482 The number of samples is not defined when running with `steps`, 

483 in which case the number of samples is set to `None`. 

484 

485 Args: 

486 ins: List of tensors to be fed to the Keras function. 

487 batch_size: Integer batch size or `None` if not defined. 

488 steps: Total number of steps (batches of samples) before declaring 

489 `_predict_loop` finished. Ignored with the default value of `None`. 

490 steps_name: The public API's parameter name for `steps`. 

491 

492 Raises: 

493 ValueError: when `steps` is `None` and the attribute `ins.shape` 

494 does not exist. Also raises ValueError when `steps` is not `None` 

495 and `batch_size` is not `None` because they are mutually 

496 exclusive. 

497 

498 Returns: 

499 When steps is `None`, returns the number of samples to be 

500 processed based on the size of the first dimension of the 

501 first input numpy array. When steps is not `None` and 

502 `batch_size` is `None`, returns `None`. 

503 """ 

504 if steps is not None and batch_size is not None: 

505 raise ValueError('If ' + steps_name + 

506 ' is set, the `batch_size` must be None.') 

507 if check_steps_argument(ins, steps, steps_name): 

508 return None 

509 

510 if hasattr(ins[0], 'shape'): 

511 return int(ins[0].shape[0]) 

512 return None # Edge case where ins == [static_learning_phase] 

513 

514 
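
A small sketch of the in-memory-array case: the count comes from the first input's 0th dimension (steps-driven inputs such as iterators return `None` instead):

import numpy as np

ins = [np.zeros((10, 4)), np.zeros((10, 1))]
print(check_num_samples(ins, batch_size=2))   # 10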

515def standardize_single_array(x, expected_shape=None): 

516 """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1.""" 

517 if x is None: 

518 return None 

519 

520 if is_composite_or_composite_value(x): 

521 return x 

522 

523 if isinstance(x, int): 

524 raise ValueError( 

525 'Expected an array data type but received an integer: {}'.format(x)) 

526 

527 if (x.shape is not None and len(x.shape) == 1 and 

528 (expected_shape is None or len(expected_shape) != 1)): 

529 if tensor_util.is_tf_type(x): 

530 x = array_ops.expand_dims(x, axis=1) 

531 else: 

532 x = np.expand_dims(x, 1) 

533 return x 

534 

535 
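
For example, a rank-1 target of shape (3,) is expanded to (3, 1) unless the model itself expects a rank-1 array:

import numpy as np

print(standardize_single_array(np.array([1., 2., 3.])).shape)                       # (3, 1)
print(standardize_single_array(np.array([1., 2., 3.]), expected_shape=(3,)).shape)  # (3,)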

536def get_composite_shape(tensor): 

537 """Returns the shape of the passed composite tensor.""" 

538 if isinstance(tensor, sparse_tensor.SparseTensorValue): 

539 # SparseTensorValues use a 'dense_shape' attribute 

540 return tensor.dense_shape 

541 else: 

542 return tensor.shape 

543 

544 

545def standardize_input_data(data, 

546 names, 

547 shapes=None, 

548 check_batch_axis=True, 

549 exception_prefix=''): 

550 """Normalizes inputs and targets provided by users. 

551 

552 Users may pass data as a list of arrays, dictionary of arrays, 

553 or as a single array. We normalize this to an ordered list of 

554 arrays (same order as `names`), while checking that the provided 

555 arrays have shapes that match the network's expectations. 

556 

557 Args: 

558 data: User-provided input data (polymorphic). 

559 names: List of expected array names. 

560 shapes: Optional list of expected array shapes. 

561 check_batch_axis: Boolean; whether to check that the batch axis of the 

562 arrays matches the expected value found in `shapes`. 

563 exception_prefix: String prefix used for exception formatting. 

564 

565 Returns: 

566 List of standardized input arrays (one array per model input). 

567 

568 Raises: 

569 ValueError: in case of improperly formatted user-provided data. 

570 """ 

571 try: 

572 data_len = len(data) 

573 except TypeError: 

574 # For instance if data is `None` or a symbolic Tensor. 

575 data_len = None 

576 

577 if not names: 

578 if data_len and not isinstance(data, dict): 

579 raise ValueError( 

580 'Error when checking model ' + exception_prefix + ': ' 

581 'expected no data, but got:', data) 

582 return [] 

583 if data is None: 

584 return [None for _ in range(len(names))] 

585 

586 if isinstance(data, dict): 

587 try: 

588 data = [ 

589 data[x].values 

590 if data[x].__class__.__name__ == 'DataFrame' else data[x] 

591 for x in names 

592 ] 

593 except KeyError as e: 

594 raise ValueError('No data provided for "' + e.args[0] + '". Need data ' 

595 'for each key in: ' + str(names)) 

596 elif isinstance(data, (list, tuple)): 

597 if isinstance(data[0], (list, tuple)): 

598 data = [np.asarray(d) for d in data] 

599 elif len(names) == 1 and isinstance(data[0], (float, int)): 

600 data = [np.asarray(data)] 

601 else: 

602 data = [ 

603 x.values if x.__class__.__name__ == 'DataFrame' else x for x in data 

604 ] 

605 else: 

606 data = data.values if data.__class__.__name__ == 'DataFrame' else data 

607 data = [data] 

608 

609 if shapes is not None: 

610 data = [ 

611 standardize_single_array(x, shape) for (x, shape) in zip(data, shapes) 

612 ] 

613 else: 

614 data = [standardize_single_array(x) for x in data] 

615 

616 if len(data) != len(names): 

617 if data and hasattr(data[0], 'shape'): 

618 raise ValueError('Error when checking model ' + exception_prefix + 

619 ': the list of Numpy arrays that you are passing to ' 

620 'your model is not the size the model expected. ' 

621 'Expected to see ' + str(len(names)) + ' array(s), ' + 

622 'for inputs ' + str(names) + ' but instead got the ' 

623 'following list of ' + str(len(data)) + ' arrays: ' + 

624 str(data)[:200] + '...') 

625 elif len(names) > 1: 

626 raise ValueError('Error when checking model ' + exception_prefix + 

627 ': you are passing a list as input to your model, ' 

628 'but the model expects a list of ' + str(len(names)) + 

629 ' Numpy arrays instead. The list you passed was: ' + 

630 str(data)[:200]) 

631 elif len(data) == 1 and not hasattr(data[0], 'shape'): 

632 raise TypeError('Error when checking model ' + exception_prefix + 

633 ': data should be a Numpy array, or list/dict of ' 

634 'Numpy arrays. Found: ' + str(data)[:200] + '...') 

635 elif len(names) == 1: 

636 data = [np.asarray(data)] 

637 

638 # Check shapes compatibility. 

639 if shapes: 

640 for i in range(len(names)): 

641 if shapes[i] is not None: 

642 if tensor_util.is_tf_type(data[i]): 

643 tensorshape = data[i].shape 

644 if not tensorshape: 

645 continue 

646 data_shape = tuple(tensorshape.as_list()) 

647 elif is_composite_or_composite_value(data[i]): 

648 tensorshape = get_composite_shape(data[i]) 

649 data_shape = tuple(tensorshape.as_list()) 

650 else: 

651 data_shape = data[i].shape 

652 

653 shape = shapes[i] 

654 if len(data_shape) != len(shape): 

655 raise ValueError('Error when checking ' + exception_prefix + 

656 ': expected ' + names[i] + ' to have ' + 

657 str(len(shape)) + ' dimensions, but got array ' 

658 'with shape ' + str(data_shape)) 

659 if not check_batch_axis: 

660 data_shape = data_shape[1:] 

661 shape = shape[1:] 

662 for dim, ref_dim in zip(data_shape, shape): 

663 if ref_dim != dim and ref_dim is not None and dim is not None: 

664 raise ValueError('Error when checking ' + exception_prefix + 

665 ': expected ' + names[i] + ' to have shape ' + 

666 str(shape) + ' but got array with shape ' + 

667 str(data_shape)) 

668 return data 

669 

670 
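
A hypothetical two-input example (input names invented for illustration): a dict keyed by input name is normalized to a list ordered like `names`, and the arrays are checked against `shapes`:

import numpy as np

data = {'img': np.zeros((8, 4)), 'meta': np.zeros((8, 2))}
standardized = standardize_input_data(
    data, names=['img', 'meta'], shapes=[(None, 4), (None, 2)],
    exception_prefix='input')
print([d.shape for d in standardized])   # [(8, 4), (8, 2)]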

671def standardize_sample_or_class_weights(x_weight, output_names, weight_type): 

672 """Maps `sample_weight` or `class_weight` to model outputs. 

673 

674 Args: 

675 x_weight: User-provided `sample_weight` or `class_weight` argument. 

676 output_names: List of output names (strings) in the model. 

677 weight_type: A string used purely for exception printing. 

678 

679 Returns: 

680 A list of `sample_weight` or `class_weight` where there is exactly 

681 one element per model output. 

682 

683 Raises: 

684 ValueError: In case of invalid user-provided argument. 

685 """ 

686 if x_weight is None or (isinstance(x_weight, (list, tuple)) and 

687 len(x_weight) == 0): # pylint: disable=g-explicit-length-test 

688 return [None for _ in output_names] 

689 if len(output_names) == 1: 

690 if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1: 

691 return x_weight 

692 if isinstance(x_weight, dict) and output_names[0] in x_weight: 

693 return [x_weight[output_names[0]]] 

694 else: 

695 return [x_weight] 

696 if isinstance(x_weight, (list, tuple)): 

697 if len(x_weight) != len(output_names): 

698 raise ValueError('Provided `' + weight_type + '` was a list of ' + 

699 str(len(x_weight)) + ' elements, but the model has ' + 

700 str(len(output_names)) + ' outputs. ' 

701 'You should provide one `' + weight_type + '` ' 

702 'array per model output.') 

703 return x_weight 

704 if isinstance(x_weight, collections.abc.Mapping): 

705 generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names) 

706 x_weights = [] 

707 for name in output_names: 

708 x_weights.append(x_weight.get(name)) 

709 return x_weights 

710 else: 

711 raise TypeError('The model has multiple outputs, so `' + weight_type + '` ' 

712 'should be either a list or a dict. ' 

713 'Provided `' + weight_type + '` type not understood: ' + 

714 str(x_weight)) 

715 

716 

717def standardize_class_weights(class_weight, output_names): 

718 return standardize_sample_or_class_weights(class_weight, output_names, 

719 'class_weight') 

720 

721 

722def standardize_sample_weights(sample_weight, output_names): 

723 return standardize_sample_or_class_weights(sample_weight, output_names, 

724 'sample_weight') 

725 

726 

727def check_array_lengths(inputs, targets, weights=None): 

728 """Does user input validation for numpy arrays. 

729 

730 Args: 

731 inputs: list of Numpy arrays of inputs. 

732 targets: list of Numpy arrays of targets. 

733 weights: list of Numpy arrays of sample weights. 

734 

735 Raises: 

736 ValueError: in case of incorrectly formatted data. 

737 """ 

738 

739 def is_tensor_or_composite_tensor(x): 

740 return tensor_util.is_tf_type(x) or is_composite_or_composite_value(x) 

741 

742 def set_of_lengths(x): 

743 # Returns a set with the variation between 

744 # different shapes, with None => 0 

745 if x is None: 

746 return {} 

747 else: 

748 return set([ 

749 y.shape[0] 

750 for y in x 

751 if y is not None and not is_tensor_or_composite_tensor(y) 

752 ]) 

753 

754 set_x = set_of_lengths(inputs) 

755 set_y = set_of_lengths(targets) 

756 set_w = set_of_lengths(weights) 

757 if len(set_x) > 1: 

758 raise ValueError('All input arrays (x) should have ' 

759 'the same number of samples. Got array shapes: ' + 

760 str([x.shape for x in inputs])) 

761 if len(set_y) > 1: 

762 raise ValueError('All target arrays (y) should have ' 

763 'the same number of samples. Got array shapes: ' + 

764 str([y.shape for y in targets])) 

765 if set_x and set_y and list(set_x)[0] != list(set_y)[0]: 

766 raise ValueError('Input arrays should have ' 

767 'the same number of samples as target arrays. ' 

768 'Found ' + str(list(set_x)[0]) + ' input samples ' 

769 'and ' + str(list(set_y)[0]) + ' target samples.') 

770 if len(set_w) > 1: 

771 raise ValueError('All sample_weight arrays should have ' 

772 'the same number of samples. Got array shapes: ' + 

773 str([w.shape for w in weights])) 

774 if set_y and set_w and list(set_y)[0] != list(set_w)[0]: 

775 raise ValueError('Sample_weight arrays should have ' 

776 'the same number of samples as target arrays. Got ' + 

777 str(list(set_y)[0]) + ' input samples and ' + 

778 str(list(set_w)[0]) + ' target samples.') 

779 

780 

781def check_loss_and_target_compatibility(targets, loss_fns, output_shapes): 

782 """Does validation on the compatibility of targets and loss functions. 

783 

784 This helps prevent users from using loss functions incorrectly. This check 

785 is purely for UX purposes. 

786 

787 Args: 

788 targets: list of Numpy arrays of targets. 

789 loss_fns: list of loss functions. 

790 output_shapes: list of shapes of model outputs. 

791 

792 Raises: 

793 ValueError: if a loss function or target array 

794 is incompatible with an output. 

795 """ 

796 key_loss_fns = { 

797 losses.mean_squared_error, losses.binary_crossentropy, 

798 losses.categorical_crossentropy 

799 } 

800 key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy, 

801 losses.CategoricalCrossentropy) 

802 for y, loss, shape in zip(targets, loss_fns, output_shapes): 

803 if y is None or loss is None or tensor_util.is_tf_type(y): 

804 continue 

805 if losses.is_categorical_crossentropy(loss): 

806 if y.shape[-1] == 1: 

807 raise ValueError('You are passing a target array of shape ' + 

808 str(y.shape) + 

809 ' while using as loss `categorical_crossentropy`. ' 

810 '`categorical_crossentropy` expects ' 

811 'targets to be binary matrices (1s and 0s) ' 

812 'of shape (samples, classes). ' 

813 'If your targets are integer classes, ' 

814 'you can convert them to the expected format via:\n' 

815 '```\n' 

816 'from keras.utils import to_categorical\n' 

817 'y_binary = to_categorical(y_int)\n' 

818 '```\n' 

819 '\n' 

820 'Alternatively, you can use the loss function ' 

821 '`sparse_categorical_crossentropy` instead, ' 

822 'which does expect integer targets.') 

823 

824 is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper) 

825 if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and 

826 (loss.fn in key_loss_fns))): 

827 for target_dim, out_dim in zip(y.shape[1:], shape[1:]): 

828 if out_dim is not None and target_dim != out_dim: 

829 loss_name = loss.name 

830 if loss_name is None: 

831 loss_type = loss.fn if is_loss_wrapper else type(loss) 

832 loss_name = loss_type.__name__ 

833 raise ValueError('A target array with shape ' + str(y.shape) + 

834 ' was passed for an output of shape ' + str(shape) + 

835 ' while using as loss `' + loss_name + '`. ' 

836 'This loss expects targets to have the same shape ' 

837 'as the output.') 

838 

839 

840def collect_per_output_metric_info(metrics, 

841 output_names, 

842 output_shapes, 

843 loss_fns, 

844 from_serialized=False, 

845 is_weighted=False): 

846 """Maps metric names and functions to model outputs. 

847 

848 Args: 

849 metrics: a list or a list of lists or a dict of metric functions. 

850 output_names: a list of the names (strings) of model outputs. 

851 output_shapes: a list of the shapes (strings) of model outputs. 

852 loss_fns: a list of the loss functions corresponding to the model outputs. 

853 from_serialized: whether the model the metrics are being sourced from is 

854 being initialized from a serialized format. 

855 is_weighted: Boolean indicating whether the given metrics are weighted. 

856 

857 Returns: 

858 A list (one entry per model output) of dicts. 

859 For instance, if the model has 2 outputs, and for the first output 

860 we want to compute "binary_accuracy" and "binary_crossentropy", 

861 and just "binary_accuracy" for the second output, 

862 the list would look like: `[{ 

863 'acc': binary_accuracy(), 

864 'ce': binary_crossentropy(), 

865 }, { 

866 'acc': binary_accuracy(), 

867 }]` 

868 

869 Raises: 

870 TypeError: if an incorrect type is passed for the `metrics` argument. 

871 """ 

872 if not metrics: 

873 return [{} for _ in output_names] 

874 

875 if isinstance(metrics, list): 

876 any_sub_list = any(isinstance(m, list) for m in metrics) 

877 if any_sub_list: 

878 if len(metrics) != len(output_names): 

879 raise ValueError('When passing a list of lists as `metrics`, ' 

880 'it should have one entry per model output. ' 

881 'The model has ' + str(len(output_names)) + 

882 ' outputs, but you passed metrics=' + str(metrics)) 

883 # User has provided a list of len = len(outputs). 

884 nested_metrics = [generic_utils.to_list(m) for m in metrics] 

885 else: 

886 # If it is a single list we then apply all metrics to all outputs. 

887 if len(output_names) > 1: 

888 nested_metrics = [] 

889 for _ in output_names: 

890 nested_metrics.append( 

891 [metrics_module.clone_metric(m) for m in metrics]) 

892 else: 

893 nested_metrics = [metrics] 

894 elif isinstance(metrics, collections.abc.Mapping): 

895 generic_utils.check_for_unexpected_keys('metrics', metrics, output_names) 

896 nested_metrics = [] 

897 for name in output_names: 

898 output_metrics = generic_utils.to_list(metrics.get(name, [])) 

899 nested_metrics.append(output_metrics) 

900 else: 

901 raise TypeError('Type of `metrics` argument not understood. ' 

902 'Expected a list or dictionary, found: ' + str(metrics)) 

903 

904 per_output_metrics = [] 

905 for i, metrics in enumerate(nested_metrics): 

906 metrics_dict = collections.OrderedDict() 

907 for metric in metrics: 

908 metric_name = get_metric_name(metric, is_weighted) 

909 metric_fn = get_metric_function( 

910 metric, output_shape=output_shapes[i], loss_fn=loss_fns[i]) 

911 metric_fn._from_serialized = from_serialized # pylint: disable=protected-access 

912 

913 # If the metric function is not stateful, we create a stateful version. 

914 if not isinstance(metric_fn, metrics_module.Metric): 

915 metric_fn = metrics_module.MeanMetricWrapper( 

916 metric_fn, name=metric_name) 

917 # If the metric is being revived from something stateless, such as a 

918 # string (e.g. "accuracy"), we may need to later reapply transformations 

919 # such as renaming. 

920 metric_fn._from_serialized = False # pylint: disable=protected-access 

921 metrics_dict[metric_name] = metric_fn 

922 per_output_metrics.append(metrics_dict) 

923 

924 return per_output_metrics 

925 

926 

927def batch_shuffle(index_array, batch_size): 

928 """Shuffles an array in a batch-wise fashion. 

929 

930 Useful for shuffling HDF5 arrays 

931 (where one cannot access arbitrary indices). 

932 

933 Args: 

934 index_array: array of indices to be shuffled. 

935 batch_size: integer. 

936 

937 Returns: 

938 The `index_array` array, shuffled in a batch-wise fashion. 

939 """ 

940 batch_count = int(len(index_array) / batch_size) 

941 # to reshape we need to be cleanly divisible by batch size 

942 # we stash extra items and reappend them after shuffling 

943 last_batch = index_array[batch_count * batch_size:] 

944 index_array = index_array[:batch_count * batch_size] 

945 index_array = index_array.reshape((batch_count, batch_size)) 

946 np.random.shuffle(index_array) 

947 index_array = index_array.flatten() 

948 return np.append(index_array, last_batch) 

949 

950 
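
A worked sketch: with 10 indices and batch_size=3, the three full batches are shuffled as contiguous blocks and the leftover index 9 is re-appended at the end:

import numpy as np

np.random.seed(0)
print(batch_shuffle(np.arange(10), batch_size=3))
# e.g. [6 7 8 3 4 5 0 1 2 9] - each block of 3 stays contiguous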

951def standardize_weights(y, 

952 sample_weight=None, 

953 class_weight=None, 

954 sample_weight_mode=None): 

955 """Performs sample weight validation and standardization. 

956 

957 Everything gets normalized to a single sample-wise (or timestep-wise) 

958 weight array. If both `sample_weight` and `class_weight` are provided, 

959 the weights are multiplied. 

960 

961 Args: 

962 y: Numpy array or Tensor of model targets to be weighted. 

963 sample_weight: User-provided `sample_weight` argument. 

964 class_weight: User-provided `class_weight` argument. 

965 sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicates 

966 that we expect 2D weight data that will be applied to the last 2 

967 dimensions of the targets (i.e. we are weighting timesteps, not 

968 samples). 

969 

970 Returns: 

971 A numpy array of target weights, one entry per sample to weight. 

972 

973 Raises: 

974 ValueError: In case of invalid user-provided arguments. 

975 """ 

976 # Iterator may return sample_weight as 1-tuple 

977 if isinstance(sample_weight, tuple): 

978 sample_weight = sample_weight[0] 

979 if sample_weight_mode is not None and sample_weight_mode != 'samplewise': 

980 if sample_weight_mode != 'temporal': 

981 raise ValueError('`sample_weight_mode` ' 

982 'should be None or "temporal". ' 

983 'Found: ' + str(sample_weight_mode)) 

984 if len(y.shape) < 3: 

985 raise ValueError('Found a sample_weight array for ' 

986 'an input with shape ' + str(y.shape) + '. ' 

987 'Timestep-wise sample weighting (use of ' 

988 'sample_weight_mode="temporal") is restricted to ' 

989 'outputs that are at least 3D, i.e. that have ' 

990 'a time dimension.') 

991 if sample_weight is not None and len(sample_weight.shape) != 2: 

992 raise ValueError('Found a sample_weight array with shape ' + 

993 str(sample_weight.shape) + '. ' 

994 'In order to use timestep-wise sample weighting, ' 

995 'you should pass a 2D sample_weight array.') 

996 else: 

997 if sample_weight is not None and len(sample_weight.shape) != 1: 

998 raise ValueError( 

999 'Found a sample_weight array with shape {}. In order to ' 

1000 'use timestep-wise sample weights, you should specify ' 

1001 'sample_weight_mode="temporal" in compile(); found "{}" ' 

1002 'instead. If you just mean to use sample-wise weights, ' 

1003 'make sure your sample_weight array is 1D.'.format( 

1004 sample_weight.shape, sample_weight_mode)) 

1005 

1006 if sample_weight is not None: 

1007 if len(sample_weight.shape) > len(y.shape): 

1008 raise ValueError('Found a sample_weight with shape ' + 

1009 str(sample_weight.shape) + '. ' 

1010 'Expected sample_weight with rank ' 

1011 'less than or equal to ' + str(len(y.shape))) 

1012 

1013 if (not tensor_util.is_tf_type(sample_weight) and 

1014 y.shape[:sample_weight.ndim] != sample_weight.shape): 

1015 raise ValueError('Found a sample_weight array with shape ' + 

1016 str(sample_weight.shape) + ' for an input with shape ' + 

1017 str(y.shape) + '. ' 

1018 'sample_weight cannot be broadcast.') 

1019 

1020 # Class weights applied per-sample. 

1021 class_sample_weight = None 

1022 if isinstance(class_weight, dict): 

1023 if len(y.shape) > 2: 

1024 raise ValueError('`class_weight` not supported for ' 

1025 '3+ dimensional targets.') 

1026 

1027 if tensor_util.is_tf_type(y): 

1028 # Few classes are expected, so densifying is reasonable. 

1029 keys = np.array(sorted(class_weight.keys())) 

1030 values = np.array([class_weight[i] for i in keys]) 

1031 weight_vector = np.zeros(np.max(keys) + 1) 

1032 weight_vector[:] = np.nan 

1033 weight_vector[keys] = values 

1034 

1035 y_classes = smart_cond.smart_cond( 

1036 len(y.shape.as_list()) == 2 and backend.shape(y)[1] > 1, 

1037 lambda: backend.argmax(y, axis=1), 

1038 lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64)) 

1039 class_sample_weight = array_ops.gather(weight_vector, y_classes) 

1040 gen_array_ops.check_numerics( 

1041 class_sample_weight, 

1042 'Invalid classes or class weights detected. NaN values indicate that ' 

1043 'an appropriate class weight could not be determined.') 

1044 class_sample_weight = math_ops.cast(class_sample_weight, backend.floatx()) 

1045 if sample_weight is not None: 

1046 sample_weight = math_ops.cast( 

1047 tensor_conversion.convert_to_tensor_v2_with_dispatch(sample_weight), 

1048 backend.floatx(), 

1049 ) 

1050 else: 

1051 y_classes = y 

1052 if len(y.shape) == 2: 

1053 if y.shape[1] > 1: 

1054 y_classes = np.argmax(y, axis=1) 

1055 elif y.shape[1] == 1: 

1056 y_classes = np.reshape(y, y.shape[0]) 

1057 

1058 class_sample_weight = np.asarray( 

1059 [class_weight[cls] for cls in y_classes if cls in class_weight]) 

1060 

1061 if len(class_sample_weight) != len(y_classes): 

1062 # subtract the sets to pick all missing classes 

1063 existing_classes = set(y_classes) 

1064 existing_class_weight = set(class_weight.keys()) 

1065 raise ValueError( 

1066 '`class_weight` must contain all classes in the data.' 

1067 ' The classes %s exist in the data but not in ' 

1068 '`class_weight`.' % (existing_classes - existing_class_weight)) 

1069 

1070 if class_sample_weight is not None and sample_weight is not None: 

1071 # Multiply weights if both are provided. 

1072 return class_sample_weight * sample_weight 

1073 if sample_weight is not None: 

1074 return sample_weight 

1075 if class_sample_weight is not None: 

1076 return class_sample_weight 

1077 return None 

1078 

1079 
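
A sketch of the class-weight path for integer targets (toy weights): each sample picks up the weight of its class:

import numpy as np

y = np.array([[0], [1], [1], [0]])
print(standardize_weights(y, class_weight={0: 1.0, 1: 3.0}))   # [1. 3. 3. 1.]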

1080def has_symbolic_tensors(ls): 

1081 if context.executing_eagerly(): 

1082 return False 

1083 return has_tensors(ls) 

1084 

1085 

1086def has_tensors(ls): 

1087 """Returns true if `ls` contains tensors.""" 

1088 # Note: at some point in time ragged tensors didn't count as tensors, so this 

1089 # returned false for ragged tensors. Making this return true fails some tests 

1090 # which would then require a steps_per_epoch argument. 

1091 if isinstance(ls, (list, tuple)): 

1092 return any( 

1093 tensor_util.is_tf_type(v) and 

1094 not isinstance(v, ragged_tensor.RaggedTensor) for v in ls) 

1095 if isinstance(ls, dict): 

1096 return any( 

1097 tensor_util.is_tf_type(v) and 

1098 not isinstance(v, ragged_tensor.RaggedTensor) 

1099 for _, v in ls.items()) 

1100 return tensor_util.is_tf_type(ls) and not isinstance( 

1101 ls, ragged_tensor.RaggedTensor) 

1102 

1103 

1104def get_metric_name(metric, weighted=False): 

1105 """Returns the name corresponding to the given metric input. 

1106 

1107 Args: 

1108 metric: Metric function name or reference. 

1109 weighted: Boolean indicating if the given metric is weighted. 

1110 

1111 Returns: 

1112 The metric name. 

1113 """ 

1114 if tf2.enabled(): 

1115 # We keep the string that the user has set in compile as the metric name. 

1116 if isinstance(metric, str): 

1117 return metric 

1118 

1119 metric = metrics_module.get(metric) 

1120 return metric.name if hasattr(metric, 'name') else metric.__name__ 

1121 else: 

1122 metric_name_prefix = 'weighted_' if weighted else '' 

1123 if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): 

1124 if metric in ('accuracy', 'acc'): 

1125 suffix = 'acc' 

1126 elif metric in ('crossentropy', 'ce'): 

1127 suffix = 'ce' 

1128 else: 

1129 metric_fn = metrics_module.get(metric) 

1130 # Get metric name as string 

1131 if hasattr(metric_fn, 'name'): 

1132 suffix = metric_fn.name 

1133 else: 

1134 suffix = metric_fn.__name__ 

1135 metric_name = metric_name_prefix + suffix 

1136 return metric_name 

1137 

1138 

1139def get_metric_function(metric, output_shape=None, loss_fn=None): 

1140 """Returns the metric function corresponding to the given metric input. 

1141 

1142 Args: 

1143 metric: Metric function name or reference. 

1144 output_shape: The shape of the output that this metric will be calculated 

1145 for. 

1146 loss_fn: The loss function used. 

1147 

1148 Returns: 

1149 The metric function. 

1150 """ 

1151 if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']: 

1152 return metrics_module.get(metric) 

1153 

1154 is_sparse_categorical_crossentropy = ( 

1155 isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or 

1156 (isinstance(loss_fn, losses.LossFunctionWrapper) and 

1157 loss_fn.fn == losses.sparse_categorical_crossentropy)) 

1158 

1159 is_binary_crossentropy = ( 

1160 isinstance(loss_fn, losses.BinaryCrossentropy) or 

1161 (isinstance(loss_fn, losses.LossFunctionWrapper) and 

1162 loss_fn.fn == losses.binary_crossentropy)) 

1163 

1164 if metric in ['accuracy', 'acc']: 

1165 if output_shape[-1] == 1 or is_binary_crossentropy: 

1166 return metrics_module.binary_accuracy 

1167 elif is_sparse_categorical_crossentropy: 

1168 return metrics_module.sparse_categorical_accuracy 

1169 # If the output_shape[-1] is not 1, then we know output is `categorical`. 

1170 # We assume it is sparse categorical only if loss is explicitly given 

1171 # as sparse categorical crossentropy loss. 

1172 return metrics_module.categorical_accuracy 

1173 else: 

1174 if output_shape[-1] == 1 or is_binary_crossentropy: 

1175 return metrics_module.binary_crossentropy 

1176 elif is_sparse_categorical_crossentropy: 

1177 return metrics_module.sparse_categorical_crossentropy 

1178 return metrics_module.categorical_crossentropy 

1179 

1180 
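
A sketch of the resolution above: the generic 'acc' string maps to different metric functions depending on the loss and the output shape:

from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module

fn = get_metric_function('acc', output_shape=(None, 1),
                         loss_fn=losses.BinaryCrossentropy())
print(fn is metrics_module.binary_accuracy)              # True

fn = get_metric_function('acc', output_shape=(None, 10),
                         loss_fn=losses.SparseCategoricalCrossentropy())
print(fn is metrics_module.sparse_categorical_accuracy)  # True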

1181def call_metric_function(metric_fn, 

1182 y_true, 

1183 y_pred=None, 

1184 weights=None, 

1185 mask=None): 

1186 """Invokes metric function and returns the metric result tensor.""" 

1187 if mask is not None: 

1188 mask = math_ops.cast(mask, y_pred.dtype) 

1189 if weights is None: 

1190 # Use mask as sample weight. 

1191 weights = mask 

1192 else: 

1193 # Update dimensions of weights to match with mask. 

1194 weights = math_ops.cast(weights, dtype=y_pred.dtype) 

1195 mask, _, weights = losses_utils.squeeze_or_expand_dimensions( 

1196 mask, sample_weight=weights) 

1197 weights *= mask 

1198 

1199 if y_pred is not None: 

1200 return metric_fn(y_true, y_pred, sample_weight=weights) 

1201 # `Mean` metric only takes a single value. 

1202 return metric_fn(y_true, sample_weight=weights) 

1203 

1204 

1205def get_loss_function(loss): 

1206 """Returns the loss corresponding to the loss input in `compile` API.""" 

1207 if loss is None or isinstance(loss, losses.Loss): 

1208 return loss 

1209 

1210 if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss): 

1211 # It is not safe to assume that the loss takes no constructor arguments. 

1212 raise ValueError( 

1213 'Received uninstantiated Loss class: {}\nPlease call loss classes ' 

1214 'before passing them to Model.compile.'.format(loss)) 

1215 

1216 # Deserialize loss configuration, if needed. 

1217 if isinstance(loss, collections.abc.Mapping): 

1218 loss = losses.get(loss) 

1219 

1220 # Custom callable class. 

1221 if callable(loss) and not hasattr(loss, '__name__'): 

1222 return loss 

1223 

1224 # Wrap loss function with signature `(y_true, y_pred, **kwargs)` 

1225 # in `LossFunctionWrapper` class. 

1226 loss_fn = losses.get(loss) 

1227 

1228 # For losses which are given as strings/functions in the compile API, 

1229 # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE` 

1230 # (both in distribution strategy context and otherwise). 

1231 return losses.LossFunctionWrapper( 

1232 loss_fn, 

1233 name=loss_fn.__name__, 

1234 reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE) 

1235 

1236 

1237def validate_dataset_input(x, y, sample_weight, validation_split=None): 

1238 """Validates user input arguments when a dataset iterator is passed. 

1239 

1240 Args: 

1241 x: Input data. A `tf.data` dataset or iterator. 

1242 y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s). 

1243 Expected to be `None` when `x` is a dataset iterator. 

1244 sample_weight: An optional sample-weight array passed by the user to weight 

1245 the importance of each sample in `x`. Expected to be `None` when `x` is a 

1246 dataset iterator. 

1247 validation_split: Float between 0 and 1. Fraction of the training data to be 

1248 used as validation data. Expected to be `None` when `x` is a dataset 

1249 iterator. 

1250 

1251 Raises: 

1252 ValueError: if argument `y` or `sample_weight` or `validation_split` are 

1253 provided by user. 

1254 """ 

1255 if y is not None: 

1256 raise ValueError('You passed a dataset or dataset iterator (%s) as ' 

1257 'input `x` to your model. In that case, you should ' 

1258 'not specify a target (`y`) argument, since the dataset ' 

1259 'or dataset iterator generates both input data and ' 

1260 'target data. ' 

1261 'Received: %s' % (x, y)) 

1262 if sample_weight is not None: 

1263 raise ValueError('`sample_weight` argument is not supported when input ' 

1264 '`x` is a dataset or a dataset iterator. Instead, you ' 

1265 'can provide sample_weight as the third element of your ' 

1266 'dataset, i.e. (inputs, targets, sample_weight). ' 

1267 'Received: x=%s, sample_weight=%s' % (x, sample_weight)) 

1268 if validation_split is not None and validation_split != 0.0: 

1269 raise ValueError( 

1270 '`validation_split` argument is not supported when ' 

1271 'input `x` is a dataset or a dataset iterator. ' 

1272 'Received: x=%s, validation_split=%f' % (x, validation_split)) 

1273 

1274 

1275def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'): 

1276 """Helper function to validate either inputs or targets.""" 

1277 if isinstance(inp, (list, tuple)): 

1278 if not all(isinstance(v, np.ndarray) or 

1279 tensor_util.is_tf_type(v) for v in inp): 

1280 raise ValueError( 

1281 'Please provide as model inputs either a single array or a list of ' 

1282 'arrays. You passed: {}={}'.format(field_name, str(orig_inp))) 

1283 elif isinstance(inp, dict): 

1284 if not allow_dict: 

1285 raise ValueError( 

1286 'You cannot pass a dictionary as model {}.'.format(field_name)) 

1287 elif not isinstance(inp, np.ndarray) and not tensor_util.is_tf_type(inp): 

1288 raise ValueError( 

1289 'Please provide as model inputs either a single array or a list of ' 

1290 'arrays. You passed: {}={}'.format(field_name, orig_inp)) 

1291 

1292 

1293def check_generator_arguments(y=None, sample_weight=None, 

1294 validation_split=None): 

1295 """Validates arguments passed when using a generator.""" 

1296 if y is not None: 

1297 raise ValueError('`y` argument is not supported when data is' 

1298 'a generator or Sequence instance. Instead pass targets' 

1299 ' as the second element of the generator.') 

1300 if sample_weight is not None: 

1301 raise ValueError('`sample_weight` argument is not supported when data is' 

1302 'a generator or Sequence instance. Instead pass sample' 

1303 ' weights as the third element of the generator.') 

1304 if validation_split: 

1305 raise ValueError('If your data is in the form of a Python generator, ' 

1306 'you cannot use `validation_split`.') 

1307 

1308 

1309def check_steps_argument(input_data, steps, steps_name): 

1310 """Validates `steps` argument based on input data's type. 

1311 

1312 The cases when `steps` value must be provided are when 

1313 1. input data passed is an iterator. 

1314 2. model was built on top of symbolic tensors, input data is not 

1315 required and is `None`. 

1316 3. input data passed is a symbolic tensor. 

1317 

1318 Args: 

1319 input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or 

1320 tf.data.Dataset iterator or `None`. 

1321 steps: Integer or `None`. Total number of steps (batches of samples) to 

1322 execute. 

1323 steps_name: The public API's parameter name for `steps`. 

1324 

1325 Returns: 

1326 boolean, True if `steps` argument is required, else False. 

1327 

1328 Raises: 

1329 ValueError: if `steps` argument is required for given input data type 

1330 but not provided. 

1331 """ 

1332 is_x_iterator = isinstance( 

1333 input_data, (iterator_ops.Iterator, iterator_ops.IteratorBase)) 

1334 if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or 

1335 (isinstance(input_data, list) and not input_data)): 

1336 if steps is None: 

1337 input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors' 

1338 raise ValueError('When using {input_type} as input to a model, you should' 

1339 ' specify the `{steps_name}` argument.'.format( 

1340 input_type=input_type_str, steps_name=steps_name)) 

1341 return True 

1342 

1343 if isinstance(input_data, (data_types.DatasetV1, data_types.DatasetV2)): 

1344 return True 

1345 

1346 if steps is not None: 

1347 list_types = (np.ndarray, list, tuple) 

1348 if (isinstance(input_data, list_types) or 

1349 (isinstance(input_data, dict) and 

1350 any(isinstance(v, list_types) for v in input_data.values()))): 

1351 logging.warning('When passing input data as arrays, do not specify ' 

1352 '`steps_per_epoch`/`steps` argument. ' 

1353 'Please use `batch_size` instead.') 

1354 return False 

1355 

1356 

1357def cast_single_tensor(x, dtype=None): 

1358 if isinstance(x, np.ndarray): 

1359 x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x) 

1360 dtype = dtype or backend.floatx() 

1361 if x.dtype.is_floating: 

1362 return math_ops.cast(x, dtype=dtype) 

1363 return x 

1364 

1365 

1366def cast_if_floating_dtype_and_mismatch(targets, outputs): 

1367 """Returns target data tensors using correct datatype. 

1368 

1369 Checks that each target and output pair are the same datatype. If not, casts 

1370 the target to the output's datatype. 

1371 

1372 Args: 

1373 targets: tensor or list of targets. 

1374 outputs: tensor or list of outputs. 

1375 

1376 Returns: 

1377 Targets in appropriate datatype. 

1378 """ 

1379 if tensor_util.is_tf_type(targets): 

1380 # There is one target, so output[0] should be the only output. 

1381 return cast_single_tensor(targets, dtype=outputs[0].dtype) 

1382 new_targets = [] 

1383 for target, out in zip(targets, outputs): 

1384 if isinstance(target, np.ndarray): 

1385 target = tensor_conversion.convert_to_tensor_v2_with_dispatch(target) 

1386 if target.dtype != out.dtype: 

1387 new_targets.append(cast_single_tensor(target, dtype=out.dtype)) 

1388 else: 

1389 new_targets.append(target) 

1390 return new_targets 

1391 

1392 
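Another illustrative sketch under the same assumptions: a float64 target is cast to match a float32 model output.

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.engine import training_utils_v1 as utils

outputs = [tf.zeros((2, 1), dtype=tf.float32)]   # stands in for model predictions
targets = [np.ones((2, 1), dtype=np.float64)]    # targets in a different float dtype
casted = utils.cast_if_floating_dtype_and_mismatch(targets, outputs)
print(casted[0].dtype)  # float32 -- the target was cast to the output's dtype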

1393def cast_if_floating_dtype(x, dtype=None): 

1394 """Casts the given data tensors to the default floating point type. 

1395 

1396 Casts only if the input is already a floating point type. 

1397 Args: 

1398 x: tensor or list/tuple of tensors. 

1399 dtype: The dtype to which Tensors should be cast. 

1400 

1401 Returns: 

1402 Converted input. 

1403 """ 

1404 return nest.map_structure(functools.partial(cast_single_tensor, dtype=dtype), 

1405 x) 

1406 

1407 
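A short sketch of the structure-preserving cast (same assumptions as above): floating-point entries are cast to `backend.floatx()`, integer entries are left alone.

import numpy as np
from tensorflow.python.keras.engine import training_utils_v1 as utils

data = {'feat': np.arange(4, dtype=np.float64), 'ids': np.arange(4, dtype=np.int32)}
out = utils.cast_if_floating_dtype(data)
print(out['feat'].dtype, out['ids'].dtype)  # float32 int32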

1408def cast_to_model_input_dtypes(x, model): 

1409 """Casts the given data tensors to the dtypes of the model inputs. 

1410 

1411 Args: 

1412 x: tensor or list/tuple of tensors. 

1413 model: The model. 

1414 

1415 Returns: 

1416 Converted input. Each tensor is cast to the corresponding input in 

1417 `model.inputs`. 

1418 """ 

1419 input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs) 

1420 return nest.map_structure(math_ops.cast, x, input_dtypes) 

1421 

1422 
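Illustrative sketch only: casting inputs to the dtypes of a small, hypothetical functional `tf.keras` model (the model itself is not taken from this file).

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.engine import training_utils_v1 as utils

inp = tf.keras.Input(shape=(3,), dtype='float32')
model = tf.keras.Model(inp, tf.keras.layers.Dense(1)(inp))
x = [np.zeros((2, 3), dtype=np.float64)]
print(utils.cast_to_model_input_dtypes(x, model)[0].dtype)  # float32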

1423def prepare_sample_weight_modes(training_endpoints, sample_weight_mode): 

1424 """Prepares sample weight modes for the model. 

1425 

1426 Args: 

1427 training_endpoints: List of model _TrainingEndpoints. 

1428 sample_weight_mode: sample weight mode user input passed from compile API. 

1429 

1430 Raises: 

1431 ValueError: In case of invalid `sample_weight_mode` input. 

1432 """ 

1433 

1434 if isinstance(sample_weight_mode, collections.abc.Mapping): 

1435 generic_utils.check_for_unexpected_keys( 

1436 'sample_weight_mode', sample_weight_mode, 

1437 [e.output_name for e in training_endpoints]) 

1438 

1439 for end_point in training_endpoints: 

1440 if not end_point.should_skip_target_weights(): 

1441 if end_point.output_name not in sample_weight_mode: 

1442 raise ValueError('Output ' + end_point.output_name + 

1443 ' missing from `_sample_weight_modes` dictionary') 

1444 else: 

1445 end_point.sample_weight_mode = sample_weight_mode.get( 

1446 end_point.output_name) 

1447 elif isinstance(sample_weight_mode, (list, tuple)): 

1448 if len(sample_weight_mode) != len(training_endpoints): 

1449 raise ValueError('When passing a list as sample_weight_mode, ' 

1450 'it should have one entry per model output. ' 

1451 'The model has ' + str(len(training_endpoints)) + 

1452 ' outputs, but you passed ' + 

1453 str(len(sample_weight_mode)) + ' sample_weight_modes.') 

1454 for mode, endpoint in zip(sample_weight_mode, training_endpoints): 

1455 if not endpoint.should_skip_target_weights(): 

1456 endpoint.sample_weight_mode = mode 

1457 else: 

1458 for endpoint in training_endpoints: 

1459 if not endpoint.should_skip_target_weights(): 

1460 endpoint.sample_weight_mode = sample_weight_mode 

1461 

1462 

1463def prepare_loss_functions(loss, output_names): 

1464 """Converts loss to a list of loss functions. 

1465 

1466 Args: 

1467 loss: String (name of objective function), objective function or 

1468 `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple 

1469 outputs, you can use a different loss on each output by passing a 

1470 dictionary or a list of losses. The loss value that will be minimized by 

1471 the model will then be the sum of all individual losses. 

1472 output_names: List of model output names. 

1473 

1474 Returns: 

1475 A list of loss objective functions. 

1476 

1477 Raises: 

1478 ValueError: If loss is a dict with keys not in model output names, 

1479 or if loss is a list with len not equal to model outputs. 

1480 """ 

1481 if isinstance(loss, collections.abc.Mapping): 

1482 generic_utils.check_for_unexpected_keys('loss', loss, output_names) 

1483 loss_functions = [] 

1484 for name in output_names: 

1485 if name not in loss: 

1486 logging.warning( 

1487 'Output {0} missing from loss dictionary. We assume ' 

1488 'this was done on purpose. The fit and evaluate APIs will not be ' 

1489 'expecting any data to be passed to {0}.'.format(name)) 

1490 loss_functions.append(get_loss_function(loss.get(name, None))) 

1491 elif isinstance(loss, str): 

1492 loss_functions = [get_loss_function(loss) for _ in output_names] 

1493 elif isinstance(loss, collections.abc.Sequence): 

1494 if len(loss) != len(output_names): 

1495 raise ValueError('When passing a list as loss, it should have one entry ' 

1496 'per model output. The model has {} outputs, but you ' 

1497 'passed loss={}'.format(len(output_names), loss)) 

1498 loss_functions = nest.map_structure(get_loss_function, loss) 

1499 else: 

1500 loss_functions = [get_loss_function(loss) for _ in range(len(output_names))] 

1501 

1502 return loss_functions 

1503 

1504 
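A minimal sketch of the three accepted `loss` forms (same import assumptions as above); output names 'a' and 'b' are made up for illustration.

from tensorflow.python.keras.engine import training_utils_v1 as utils

# One loss shared by every output:
utils.prepare_loss_functions('mse', output_names=['a', 'b'])
# A dict keyed by output name; missing keys only log a warning:
utils.prepare_loss_functions({'a': 'mse'}, output_names=['a', 'b'])
# A list must have exactly one entry per output, otherwise a ValueError is raised:
utils.prepare_loss_functions(['mse', 'mae'], output_names=['a', 'b'])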

1505def prepare_loss_weights(training_endpoints, loss_weights=None): 

1506 """Converts loss weights to a list of loss weights. 

1507 

1508 The resulting loss weights will be populated on the training endpoints. 

1509 

1510 Args: 

1511 training_endpoints: List of model training endpoints. 

1512 loss_weights: Optional list or dictionary specifying scalar coefficients 

1513 (Python floats) to weight the loss contributions of different model 

1514 outputs. The loss value that will be minimized by the model will then be 

1515 the *weighted sum* of all individual losses, weighted by the 

1516 `loss_weights` coefficients. If a list, it is expected to have a 1:1 

1517 mapping to the model's outputs. If a dict, it is expected to map 

1518 output names (strings) to scalar coefficients. 

1519 

1520 Raises: 

1521 ValueError: If loss weight is a dict with key not in model output names, 

1522 or if loss is a list with len not equal to model outputs. 

1523 """ 

1524 if loss_weights is None: 

1525 for e in training_endpoints: 

1526 e.loss_weight = 1. 

1527 elif isinstance(loss_weights, collections.abc.Mapping): 

1528 generic_utils.check_for_unexpected_keys( 

1529 'loss_weights', loss_weights, 

1530 [e.output_name for e in training_endpoints]) 

1531 for e in training_endpoints: 

1532 e.loss_weight = loss_weights.get(e.output_name, 1.) 

1533 elif isinstance(loss_weights, list): 

1534 if len(loss_weights) != len(training_endpoints): 

1535 raise ValueError('When passing a list as loss_weights, ' 

1536 'it should have one entry per model output. ' 

1537 'The model has ' + str(len(training_endpoints)) + 

1538 ' outputs, but you passed loss_weights=' + 

1539 str(loss_weights)) 

1540 for w, e in zip(loss_weights, training_endpoints): 

1541 e.loss_weight = w 

1542 else: 

1543 raise TypeError('Could not interpret loss_weights argument: ' + 

1544 str(loss_weights) + ' - expected a list or dict.') 

1545 

1546 
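Sketch of how loss weights land on the endpoints. `FakeEndpoint` is a hypothetical stand-in for the internal `_TrainingEndpoint`, exposing only the attributes this function touches.

from tensorflow.python.keras.engine import training_utils_v1 as utils

class FakeEndpoint(object):
  """Hypothetical stand-in exposing only what prepare_loss_weights uses."""

  def __init__(self, name):
    self.output_name = name
    self.loss_weight = None

endpoints = [FakeEndpoint('a'), FakeEndpoint('b')]
utils.prepare_loss_weights(endpoints, loss_weights={'a': 0.5})
print([e.loss_weight for e in endpoints])  # [0.5, 1.0] -- unlisted outputs default to 1.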

1547# TODO(rohanj): This is a hack to get around not depending on feature_column and 

1548# create a cyclical dependency. Figure out a cleaner solution 

1549def is_feature_layer(layer): 

1550 """Returns whether `layer` is a FeatureLayer or not.""" 

1551 return getattr(layer, '_is_feature_layer', False) 

1552 

1553 

1554def is_eager_dataset_or_iterator(data): 

1555 return context.executing_eagerly() and isinstance( 

1556 data, (data_types.DatasetV1, data_types.DatasetV2, 

1557 iterator_ops.IteratorBase)) 

1558 

1559 

1560# pylint: disable=protected-access 

1561def get_dataset_graph_def(dataset): 

1562 if context.executing_eagerly(): 

1563 graph_def_str = dataset._as_serialized_graph().numpy() 

1564 else: 

1565 graph_def_str = backend.get_value(dataset._as_serialized_graph()) 

1566 return graph_pb2.GraphDef().FromString(graph_def_str) 

1567 

1568 

1569def verify_dataset_shuffled(x): 

1570 """Verifies that the dataset is shuffled. 

1571 

1572 Args: 

1573 x: Dataset passed as an input to the model. 

1574 

1575 Returns: 

1576 boolean, whether the input dataset is shuffled or not. 

1577 """ 

1578 assert isinstance(x, data_types.DatasetV2) 

1579 graph_def = get_dataset_graph_def(x) 

1580 for node in graph_def.node: 

1581 if node.op.startswith('ShuffleDataset'): 

1582 return True 

1583 # Also check graph_def.library.function for ds.interleave or ds.flat_map 

1584 for function in graph_def.library.function: 

1585 for node in function.node_def: 

1586 if node.op.startswith('ShuffleDataset'): 

1587 return True 

1588 logging.warning('Expected a shuffled dataset but input dataset `x` is ' 

1589 'not shuffled. Please invoke `shuffle()` on input dataset.') 

1590 return False 

1591 

1592 
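Illustrative sketch (same assumptions as earlier examples): shuffle detection inspects the dataset's serialized graph for a ShuffleDataset node.

import tensorflow as tf
from tensorflow.python.keras.engine import training_utils_v1 as utils

ds = tf.data.Dataset.range(10).batch(2)
print(utils.verify_dataset_shuffled(ds))               # False, and a warning is logged
print(utils.verify_dataset_shuffled(ds.shuffle(10)))   # True: a ShuffleDataset node is found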

1593def is_dataset_or_iterator(data): 

1594 return isinstance(data, (data_types.DatasetV1, data_types.DatasetV2, 

1595 iterator_ops.Iterator, iterator_ops.IteratorBase)) 

1596 

1597 

1598def get_iterator(dataset): 

1599 """Create and initialize an iterator from a dataset.""" 

1600 if context.executing_eagerly(): 

1601 iterator = dataset_ops.make_one_shot_iterator(dataset) 

1602 else: 

1603 iterator = dataset_ops.make_initializable_iterator(dataset) 

1604 initialize_iterator(iterator) 

1605 return iterator 

1606 

1607 

1608def initialize_iterator(iterator): 

1609 if not context.executing_eagerly(): 

1610 init_op = iterator.initializer 

1611 backend.get_session((init_op,)).run(init_op) 

1612 

1613 

1614def extract_tensors_from_dataset(dataset): 

1615 """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset. 

1616 

1617 Args: 

1618 dataset: Dataset instance. 

1619 

1620 Returns: 

1621 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 

1622 """ 

1623 iterator = get_iterator(dataset) 

1624 inputs, targets, sample_weight = unpack_iterator_input(iterator) 

1625 return inputs, targets, sample_weight 

1626 

1627 

1628def unpack_iterator_input(iterator): 

1629 """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`. 

1630 

1631 Args: 

1632 iterator: Instance of a dataset iterator. 

1633 

1634 Returns: 

1635 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 

1636 """ 

1637 try: 

1638 next_element = iterator.get_next() 

1639 except errors.OutOfRangeError: 

1640 raise RuntimeError('Your dataset iterator ran out of data; ' 

1641 'Make sure that your dataset can generate ' 

1642 'the required number of samples.') 

1643 

1644 if isinstance(next_element, (list, tuple)): 

1645 if len(next_element) not in [2, 3]: 

1646 raise ValueError( 

1647 'Please provide model inputs as a list or tuple of 2 or 3 ' 

1648 'elements: (input, target) or (input, target, sample_weights). ' 

1649 'Received %s' % next_element) 

1650 if len(next_element) == 2: 

1651 x, y = next_element 

1652 weights = None 

1653 else: 

1654 x, y, weights = next_element 

1655 else: 

1656 x = next_element 

1657 y = None 

1658 weights = None 

1659 return x, y, weights 

1660 

1661 
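Sketch of pulling one batch of tensors out of a dataset (same assumptions as earlier examples): only the first `(x, y)` element is drawn, and missing sample weights come back as None.

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.engine import training_utils_v1 as utils

ds = tf.data.Dataset.from_tensor_slices(
    (np.zeros((4, 3), np.float32), np.ones((4, 1), np.float32))).batch(2)
x, y, weights = utils.extract_tensors_from_dataset(ds)
print(x.shape, y.shape, weights)  # (2, 3) (2, 1) None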

1662def infer_steps_for_dataset(model, 

1663 dataset, 

1664 steps, 

1665 epochs=1, 

1666 steps_name='steps'): 

1667 """Infers steps_per_epoch needed to loop through a dataset. 

1668 

1669 Args: 

1670 model: Keras model instance. 

1671 dataset: Input data of type tf.data.Dataset. 

1672 steps: Number of steps to draw from the dataset (may be None if unknown). 

1673 epochs: Number of times to iterate over the dataset. 

1674 steps_name: The string name of the steps argument, either `steps`, 

1675 `validation_steps`, or `steps_per_epoch`. Only used for error message 

1676 formatting. 

1677 

1678 Returns: 

1679 Integer or `None`. Inferred number of steps to loop through the dataset. 

1680 `None` is returned if 1) the size of the dataset is unknown and `steps` was 

1681 not specified, or 2) this is multi-worker training and auto sharding is 

1682 enabled. 

1683 

1684 Raises: 

1685 ValueError: In case of invalid argument values. 

1686 """ 

1687 assert isinstance(dataset, data_types.DatasetV2) 

1688 if (model._in_multi_worker_mode() and 

1689 (dataset.options().experimental_distribute.auto_shard_policy != 

1690 options_lib.AutoShardPolicy.OFF)): 

1691 # If the dataset would be auto-sharded, we should not infer a local 

1692 # steps_per_epoch due to the possible imbalanced sharding between workers. 

1693 return None 

1694 

1695 size = backend.get_value(cardinality.cardinality(dataset)) 

1696 if size == cardinality.INFINITE and steps is None: 

1697 raise ValueError('When passing an infinitely repeating dataset, you ' 

1698 'must specify the `%s` argument.' % (steps_name,)) 

1699 if size >= 0: 

1700 if steps is not None and steps * epochs > size: 

1701 if epochs > 1: 

1702 raise ValueError('The dataset you passed contains %s batches, but you ' 

1703 'passed `epochs=%s` and `%s=%s`, which is a total of ' 

1704 '%s steps. We cannot draw that many steps from this ' 

1705 'dataset. We suggest setting `%s=%s`.' % 

1706 (size, epochs, steps_name, steps, steps * epochs, 

1707 steps_name, size // epochs)) 

1708 else: 

1709 raise ValueError('The dataset you passed contains %s batches, but you ' 

1710 'passed `%s=%s`. We cannot draw that many steps from ' 

1711 'this dataset. We suggest setting `%s=%s`.' % 

1712 (size, steps_name, steps, steps_name, size)) 

1713 if steps is None: 

1714 if size >= 0: 

1715 return size 

1716 return None 

1717 return steps 

1718 

1719 

1720class ModelInputs(object): 

1721 """Encapsulates model inputs. 

1722 

1723 Allows for transforming model inputs while keeping the same structure. 

1724 """ 

1725 

1726 def __init__(self, inputs): 

1727 self._inputs = inputs 

1728 self._is_dict = isinstance(self._inputs, dict) 

1729 self._is_single_input = not isinstance(self._inputs, (list, tuple, dict)) 

1730 

1731 self._flattened_inputs = [] 

1732 self._input_names = [] 

1733 

1734 if self._is_dict: 

1735 for k in sorted(self._inputs.keys()): 

1736 self._flattened_inputs.append(self._inputs[k]) 

1737 self._input_names.append(k) 

1738 else: 

1739 self._flattened_inputs = nest.flatten(self._inputs) 

1740 self._input_names = [ 

1741 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs)) 

1742 ] 

1743 

1744 def get_input_names(self): 

1745 """Returns keys to name inputs by. 

1746 

1747 If the inputs provided were a list, tuple or single entry, we make up 

1748 keys 'input_%d'. For the dictionary case, we return a sorted list of keys. 

1749 """ 

1750 return self._input_names 

1751 

1752 def get_symbolic_inputs(self, return_single_as_list=False): 

1753 """Returns inputs to be set as self.inputs for a model.""" 

1754 # TODO(karmel): There is a side-effect here where what you get 

1755 # with as_list and as_dict depends on whether you have called this 

1756 # method first, since it modifies in place. 

1757 for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)): 

1758 if isinstance(v, (list, float, int)): 

1759 v = np.asarray(v) 

1760 if v.ndim == 1: 

1761 v = np.expand_dims(v, 1) 

1762 

1763 if isinstance(v, np.ndarray): 

1764 # We fix the placeholder shape except the batch size. 

1765 # This is suboptimal, but it is the best we can do with the info 

1766 # we have. The user should call `model._set_inputs(placeholders)` 

1767 # to specify custom placeholders if the need arises. 

1768 shape = (None,) + tuple(v.shape[1:]) 

1769 if shape == (None,): 

1770 shape = (None, 1) 

1771 dtype = dtypes.as_dtype(v.dtype) 

1772 if dtype.is_floating: 

1773 dtype = backend.floatx() 

1774 v = backend.placeholder(shape=shape, name=k, dtype=dtype) 

1775 elif isinstance(v, tensor_spec.TensorSpec): 

1776 shape = (None,) + tuple(v.shape.as_list()[1:]) 

1777 if shape == (None,): 

1778 shape = (None, 1) 

1779 v = backend.placeholder(shape=shape, name=k, dtype=v.dtype) 

1780 

1781 self._flattened_inputs[i] = v 

1782 

1783 if self._is_dict: 

1784 return dict(zip(self._input_names, self._flattened_inputs)) 

1785 if self._is_single_input and not return_single_as_list: 

1786 return self._flattened_inputs[0] 

1787 return self._flattened_inputs 

1788 

1789 def as_dict(self): 

1790 """An iterable over a dictionary version of inputs.""" 

1791 for k, v in zip(self._input_names, self._flattened_inputs): 

1792 yield k, v 

1793 

1794 def as_list(self): 

1795 """Returns the inputs as a list.""" 

1796 return self._flattened_inputs 

1797 

1798 
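Illustrative sketch of `ModelInputs` on a single Numpy array (same assumptions as earlier examples): the input gets a made-up name and is turned into a placeholder with an open batch dimension.

import numpy as np
from tensorflow.python.keras.engine import training_utils_v1 as utils

model_inputs = utils.ModelInputs(np.zeros((4, 3), dtype=np.float32))
print(model_inputs.get_input_names())     # ['input_1']
sym = model_inputs.get_symbolic_inputs()  # a float32 placeholder
print(sym.shape)                          # (None, 3) -- batch dimension left open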

1799# Allow use of methods not exposed to the user. 

1800# pylint: disable=protected-access 

1801 

1802 

1803# pylint: enable=protected-access 

1804 

1805 

1806def generic_output_names(outputs_list): 

1807 return ['output_%d' % (i + 1) for i in range(len(outputs_list))] 

1808 

1809 

1810def should_run_validation(validation_freq, epoch): 

1811 """Checks if validation should be run this epoch. 

1812 

1813 Args: 

1814 validation_freq: Integer or list. If an integer, specifies how many training 

1815 epochs to run before a new validation run is performed. If a list, 

1816 specifies the epochs on which to run validation. 

1817 epoch: Integer, the number of the training epoch just completed. 

1818 

1819 Returns: 

1820 Bool, True if validation should be run. 

1821 

1822 Raises: 

1823 ValueError: if `validation_freq` is an Integer and less than 1, or if 

1824 it is neither an Integer nor a Container. 

1825 """ 

1826 # `epoch` is 0-indexed internally but 1-indexed in the public API. 

1827 one_indexed_epoch = epoch + 1 

1828 

1829 if isinstance(validation_freq, int): 

1830 if validation_freq < 1: 

1831 raise ValueError('`validation_freq` can not be less than 1.') 

1832 return one_indexed_epoch % validation_freq == 0 

1833 

1834 if not isinstance(validation_freq, collections.abc.Container): 

1835 raise ValueError('`validation_freq` must be an Integer or ' 

1836 '`collections.abc.Container` (e.g. list, tuple, etc.)') 

1837 return one_indexed_epoch in validation_freq 

1838 

1839 
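A small sketch of the two `validation_freq` forms (pure Python apart from the module import assumed above); epochs are 1-indexed in the public API.

from tensorflow.python.keras.engine import training_utils_v1 as utils

# Integer frequency: validate every 2nd epoch.
print([utils.should_run_validation(2, epoch) for epoch in range(4)])
# [False, True, False, True]

# List frequency: validate only on the listed (1-indexed) epochs.
print([utils.should_run_validation([1, 4], epoch) for epoch in range(4)])
# [True, False, False, True]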

1840def split_training_and_validation_data(x, y, sample_weights, validation_split): 

1841 """Split input data into train/eval section based on validation_split.""" 

1842 if has_symbolic_tensors(x): 

1843 raise ValueError('If your data is in the form of symbolic tensors, ' 

1844 'you cannot use `validation_split`.') 

1845 if hasattr(x[0], 'shape'): 

1846 split_at = int(x[0].shape[0] * (1. - validation_split)) 

1847 else: 

1848 split_at = int(len(x[0]) * (1. - validation_split)) 

1849 x, val_x = (generic_utils.slice_arrays(x, 0, split_at), 

1850 generic_utils.slice_arrays(x, split_at)) 

1851 y, val_y = (generic_utils.slice_arrays(y, 0, split_at), 

1852 generic_utils.slice_arrays(y, split_at)) 

1853 if sample_weights: 

1854 sample_weights, val_sample_weights = ( 

1855 generic_utils.slice_arrays(sample_weights, 0, split_at), 

1856 generic_utils.slice_arrays(sample_weights, split_at), 

1857 ) 

1858 else: 

1859 val_sample_weights = None 

1860 return x, y, sample_weights, val_x, val_y, val_sample_weights 

1861 

1862 
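Sketch of the train/validation split on a toy array (same import assumptions as above): the last `validation_split` fraction of samples becomes the validation set.

import numpy as np
from tensorflow.python.keras.engine import training_utils_v1 as utils

x = [np.arange(10).reshape(10, 1)]
y = [np.arange(10).reshape(10, 1)]
x_train, y_train, sw_train, x_val, y_val, sw_val = (
    utils.split_training_and_validation_data(x, y, sample_weights=None,
                                             validation_split=0.2))
print(len(x_train[0]), len(x_val[0]))  # 8 2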

1863def unpack_validation_data(validation_data, raise_if_ambiguous=True): 

1864 """Unpack validation data based input type. 

1865 

1866 The validation data is not touched if it is a dataset or dataset iterator. 

1867 For other types of input (Numpy arrays or tensors), it will be unpacked into 

1868 a tuple of 3: x, y, and sample weights. 

1869 

1870 Args: 

1871 validation_data: dataset, dataset iterator, or a tuple of Numpy arrays or tensors. 

1872 raise_if_ambiguous: boolean on whether to fail if validation_data cannot be 

1873 parsed. Otherwise simply return validation_data, None, None and defer the 

1874 decision to the caller. 

1875 

1876 Returns: 

1877 tuple of 3, (x, y, sample_weights) for numpy and tensor input. 

1878 """ 

1879 if (isinstance(validation_data, (iterator_ops.Iterator, 

1880 iterator_ops.IteratorBase, 

1881 data_types.DatasetV2, 

1882 data_utils.Sequence)) 

1883 or not hasattr(validation_data, '__len__')): 

1884 val_x = validation_data 

1885 val_y = None 

1886 val_sample_weight = None 

1887 elif len(validation_data) == 2: 

1888 try: 

1889 val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence 

1890 val_sample_weight = None 

1891 except ValueError: 

1892 val_x, val_y, val_sample_weight = validation_data, None, None 

1893 elif len(validation_data) == 3: 

1894 try: 

1895 val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence 

1896 except ValueError: 

1897 val_x, val_y, val_sample_weight = validation_data, None, None 

1898 else: 

1899 if raise_if_ambiguous: 

1900 raise ValueError( 

1901 'When passing a `validation_data` argument, ' 

1902 'it must contain either 2 items (x_val, y_val), ' 

1903 'or 3 items (x_val, y_val, val_sample_weights), ' 

1904 'or alternatively it could be a dataset or a ' 

1905 'dataset iterator. ' 

1906 'However we received `validation_data=%s`' % validation_data) 

1907 val_x, val_y, val_sample_weight = validation_data, None, None 

1908 return val_x, val_y, val_sample_weight 

1909 

1910 
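Sketch of the unpacking behavior (same assumptions as earlier examples): an `(x, y)` tuple is split with sample weights defaulting to None, while datasets pass through untouched.

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.engine import training_utils_v1 as utils

val_x, val_y, val_sw = utils.unpack_validation_data((np.zeros((4, 3)), np.ones((4, 1))))
print(val_y.shape, val_sw)  # (4, 1) None

ds = tf.data.Dataset.from_tensor_slices(np.zeros((4, 3))).batch(2)
print(utils.unpack_validation_data(ds)[1:])  # (None, None) -- the dataset itself is returned as x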

1911class TrainingLoop(object): 

1912 """TrainingLoop is a wrapper class around the training logic. 

1913 

1914 This class encapsulates the fit/eval/predict logic for different kinds of 

1915 data input and model conditions. 

1916 

1917 Note that TrainingLoop is stateless, which means it doesn't contain any 

1918 internal fields and can be reused with different models and inputs. 

1919 """ 

1920 

1921 def fit(self, 

1922 model, 

1923 x=None, 

1924 y=None, 

1925 batch_size=None, 

1926 epochs=1, 

1927 verbose=1, 

1928 callbacks=None, 

1929 validation_split=0., 

1930 validation_data=None, 

1931 shuffle=True, 

1932 class_weight=None, 

1933 sample_weight=None, 

1934 initial_epoch=0, 

1935 steps_per_epoch=None, 

1936 validation_steps=None, 

1937 validation_freq=1, 

1938 **kwargs): 

1939 """Train the model with the inputs and targets.""" 

1940 raise NotImplementedError() 

1941 

1942 def evaluate(self, 

1943 model, 

1944 x=None, 

1945 y=None, 

1946 batch_size=None, 

1947 verbose=1, 

1948 sample_weight=None, 

1949 steps=None, 

1950 callbacks=None, 

1951 **kwargs): 

1952 """Returns the loss value & metrics values for the model in test mode.""" 

1953 raise NotImplementedError() 

1954 

1955 def predict(self, 

1956 model, 

1957 x, 

1958 batch_size=None, 

1959 verbose=0, 

1960 steps=None, 

1961 callbacks=None, 

1962 **kwargs): 

1963 raise NotImplementedError()