Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_utils_v1.py: 17%
782 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related utilities."""
17import abc
18import atexit
19import collections
20import functools
21import multiprocessing.pool
22import threading
23import time
25import numpy as np
27from tensorflow.core.framework import graph_pb2
28from tensorflow.python import tf2
29from tensorflow.python.data.experimental.ops import cardinality
30from tensorflow.python.data.ops import dataset_ops
31from tensorflow.python.data.ops import iterator_ops
32from tensorflow.python.data.ops import options as options_lib
33from tensorflow.python.eager import context
34from tensorflow.python.framework import composite_tensor
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import smart_cond
38from tensorflow.python.framework import sparse_tensor
39from tensorflow.python.framework import tensor_conversion
40from tensorflow.python.framework import tensor_spec
41from tensorflow.python.framework import tensor_util
42from tensorflow.python.keras import backend
43from tensorflow.python.keras import callbacks as cbks
44from tensorflow.python.keras import losses
45from tensorflow.python.keras import metrics as metrics_module
46from tensorflow.python.keras.utils import data_utils
47from tensorflow.python.keras.utils import generic_utils
48from tensorflow.python.keras.utils import losses_utils
49from tensorflow.python.keras.utils import tf_inspect
50from tensorflow.python.ops import array_ops
51from tensorflow.python.ops import gen_array_ops
52from tensorflow.python.ops import math_ops
53from tensorflow.python.ops import sparse_ops
54from tensorflow.python.ops.ragged import ragged_tensor
55from tensorflow.python.ops.ragged import ragged_tensor_value
56from tensorflow.python.platform import tf_logging as logging
57from tensorflow.python.types import data as data_types
58from tensorflow.python.util import nest
def is_composite_or_composite_value(tensor):
  """Checks whether `tensor` is a CompositeTensor or a composite Value object.

  Args:
    tensor: The object to inspect.

  Returns:
    True if `tensor` is an instance of `CompositeTensor`, `SparseTensorValue`
    or `RaggedTensorValue`; False otherwise.
  """
  # TODO(b/125094323): This should be isinstance(CompositeTensor) or
  # isinstance(CompositeTensorValue) once we support that.
  composite_types = (
      composite_tensor.CompositeTensor,
      sparse_tensor.SparseTensorValue,
      ragged_tensor_value.RaggedTensorValue,
  )
  return isinstance(tensor, composite_types)
class Aggregator(metaclass=abc.ABCMeta):
  """Abstract base class used to aggregate batch-level outputs of a loop.

  Subclasses implement `create`, `aggregate` and `finalize`; the accumulated
  output is exposed through `results`.

  Attributes:
    use_steps: Whether the loop is using `step` or `batch_size`.
    num_samples: Total number of samples: `batch_size * num_batches`.
    steps: Total number of steps.
    batch_size: Batch size. It is used for validation checks between inputs and
      outputs.
    results: What to return at the end of the aggregation loop.
  """

  def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None):
    self.results = []
    self.batch_size = batch_size
    self.steps = steps
    self.num_samples = num_samples
    self.use_steps = use_steps

  @abc.abstractmethod
  def create(self, batch_outs):
    """Builds the initial results from the outputs of the first batch.

    Args:
      batch_outs: A list of batch-level outputs.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Folds the outputs of one batch into the total results.

    Args:
      batch_outs: A list of batch-level outputs.
      batch_start: The start index of this batch. Always `None` if `use_steps`
        is `True`.
      batch_end: The end index of this batch. Always `None` if `use_steps` is
        `True`.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def finalize(self):
    """Prepares the total results to be returned."""
    raise NotImplementedError('Must be implemented in subclasses.')
class MetricsAggregator(Aggregator):
  """Aggregator that calculates loss and metrics info.

  Position 0 of `results` accumulates the loss; the remaining positions are
  overwritten with the latest metric values on every batch.

  Attributes:
    use_steps: Whether the loop is using `step` or `batch_size`.
    num_samples: Total number of samples: `batch_size*num_batches`.
    steps: Total number of steps, ie number of times to iterate over a dataset
      to cover all samples.
  """

  def __init__(self, use_steps, num_samples=None, steps=None):
    super().__init__(
        use_steps=use_steps,
        num_samples=num_samples,
        steps=steps,
        batch_size=None)

  def create(self, batch_outs):
    """Initializes one zero accumulator per entry of `batch_outs`."""
    num_outputs = len(batch_outs)
    self.results = [0.] * num_outputs

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Adds the batch loss to the running total and stores current metrics."""
    # Loss: weighted by the number of samples in the batch unless the loop is
    # step-based.
    batch_loss = batch_outs[0]
    if not self.use_steps:
      batch_loss = batch_loss * (batch_end - batch_start)
    self.results[0] += batch_loss
    # Metrics (always stateful, just grab current values.)
    self.results[1:] = batch_outs[1:]

  def finalize(self):
    """Turns the accumulated loss into a mean over samples (or steps).

    Raises:
      ValueError: if nothing was aggregated.
    """
    if not self.results:
      raise ValueError('Empty training data.')
    divisor = self.num_samples or self.steps
    self.results[0] /= divisor
def _append_sparse_tensor_value(target, to_append):
  """Appends one SparseTensorValue to another along the 0th (batch) axis.

  Neither argument is modified; a new SparseTensorValue is returned.

  Args:
    target: SparseTensorValue that is appended to.
    to_append: SparseTensorValue whose entries are placed after `target`'s.

  Returns:
    A SparseTensorValue whose indices/values are those of `target` followed by
    those of `to_append` (with the batch index shifted by `target`'s 0th
    dimension), and whose 0th dense dimension is the sum of both inputs' 0th
    dimensions.

  Raises:
    RuntimeError: if the dense shapes differ in rank or in any dimension other
      than the 0th.
  """
  # Make sure the sparse tensors are of the same size (except for the 0th dim).
  if len(target.dense_shape) != len(to_append.dense_shape):
    raise RuntimeError(
        'Unable to concatenate %s and %s. The inner dense shapes do not '
        'have the same number of dimensions (%s vs %s)' %
        (target, to_append, target.dense_shape, to_append.dense_shape))

  # Compare as tuples: `dense_shape` may be a NumPy array, and truth-testing an
  # elementwise array comparison raises when it has more than one element.
  if tuple(target.dense_shape[1:]) != tuple(to_append.dense_shape[1:]):
    raise RuntimeError(
        'Unable to concatenate %s and %s. The inner dense shapes do not '
        'match inner dimensions (%s vs %s)' %
        (target, to_append, target.dense_shape[1:], to_append.dense_shape[1:]))

  # Shift the batch index (column 0) of every appended index by the number of
  # batch items already in `target` (so an index of [0, 0, 1] appended to a
  # tensor with 0th dim size 3 becomes [3, 0, 1]). This works on a copy so the
  # caller's `to_append.indices` is left untouched.
  base_dim0_value = target.dense_shape[0]
  new_indices = target.indices
  shifted_indices = np.array(to_append.indices)
  if shifted_indices.size:
    shifted_indices[:, 0] += base_dim0_value
    new_indices = np.concatenate((new_indices, shifted_indices), axis=0)

  # Extend the values array to contain all of the appended values. These are
  # in the same order as the indices added above.
  new_values = np.concatenate((target.values, to_append.values), axis=0)

  # The 0th dimension of a concatenation is the sum of the inputs' 0th
  # dimensions. Deriving it from the largest appended index would drop
  # trailing rows without entries and over-count an empty `to_append`.
  new_dense_shape = list(target.dense_shape)
  new_dense_shape[0] = base_dim0_value + to_append.dense_shape[0]
  new_dense_shape = tuple(new_dense_shape)

  return sparse_tensor.SparseTensorValue(
      indices=new_indices, values=new_values, dense_shape=new_dense_shape)
def _append_ragged_tensor_value(target, to_append):
  """Appends one RaggedTensorValue to another along the 0th axis.

  Args:
    target: RaggedTensorValue that is appended to.
    to_append: RaggedTensorValue to place after `target`.

  Returns:
    A new RaggedTensorValue built from the concatenated values and row splits.

  Raises:
    RuntimeError: if the shapes differ in rank or in any dimension other than
      the 0th.
  """
  # Make sure the ragged tensors are of the same size (save for the 0th dim).
  same_rank = len(target.shape) == len(to_append.shape)
  if not same_rank or target.shape[1:] != to_append.shape[1:]:
    raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))

  # Row splits of the appended tensor continue from where `target` ends; its
  # leading split is skipped because `target`'s last split already covers it.
  split_offset = target.row_splits[-1]
  shifted_row_splits = to_append.row_splits[1:] + split_offset
  merged_row_splits = np.append(target.row_splits, shifted_row_splits)

  inner_is_ragged = isinstance(target.values,
                               ragged_tensor_value.RaggedTensorValue)
  if inner_is_ragged:
    # Nested ragged values are merged recursively.
    merged_values = _append_ragged_tensor_value(target.values,
                                                to_append.values)
  else:
    merged_values = np.concatenate((target.values, to_append.values), axis=0)

  return ragged_tensor_value.RaggedTensorValue(merged_values,
                                               merged_row_splits)
def _append_composite_tensor(target, to_append):
  """Helper function to append composite tensors to each other in the 0 axis.

  In order to support batching within a fit/evaluate/predict call, we need
  to be able to aggregate within a CompositeTensor. Unfortunately, the CT
  API currently does not make this easy - especially in V1 mode, where we're
  working with CompositeTensor Value objects that have no connection with the
  CompositeTensors that created them.

  Args:
    target: CompositeTensor or CompositeTensor value object that will be
      appended to.
    to_append: CompositeTensor or CompositeTensor value object to append to
      `target`. Must be of exactly the same type as `target`.

  Returns:
    A CompositeTensor or CompositeTensor value object.

  Raises:
    RuntimeError: if concatenation is not possible.
  """
  target_type = type(target)
  if target_type is not type(to_append):
    raise RuntimeError('Unable to concatenate %s and %s' %
                       (target_type, type(to_append)))

  # Perform type-specific concatenation.
  # TODO(b/125094323): This should be replaced by a simple call to
  # target.append() that should work on all of the below classes.

  # If we're seeing a CompositeTensor here, we know it's because we're in
  # Eager mode (or else we'd have evaluated the CT to a CT Value object
  # already). Therefore, it's safe to call concat() on it without evaluating
  # the result any further. If not - that is, if we're seeing a
  # SparseTensorValue or a RaggedTensorValue - we need to hand-update it
  # since we're outside of the graph anyways.
  if isinstance(target, sparse_tensor.SparseTensor):
    # We need to invoke the sparse version of concatenate here - tf.concat
    # won't work.
    return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0)
  if isinstance(target, ragged_tensor.RaggedTensor):
    return array_ops.concat([target, to_append], axis=0)
  if isinstance(target, sparse_tensor.SparseTensorValue):
    return _append_sparse_tensor_value(target, to_append)
  if isinstance(target, ragged_tensor_value.RaggedTensorValue):
    return _append_ragged_tensor_value(target, to_append)

  raise RuntimeError('Attempted to concatenate unsupported object %s.' %
                     target_type)
class ConcatAggregator(Aggregator):
  """Combine tensor-likes which cannot be merged on the fly.

  Batches are collected in a list and only joined in `finalize`.

  This class expects to aggregate a single tensor-like rather than a nested
  structure of tensor-likes.
  """

  def __init__(self, batch_size):
    # Set by `create`: whether the aggregated elements are composite tensors
    # (or composite tensor values) rather than plain arrays.
    self.composite = None
    super().__init__(
        use_steps=True, num_samples=None, steps=None, batch_size=batch_size)

  def create(self, batch_element):
    """Records whether `batch_element` is a composite tensor-like."""
    self.composite = is_composite_or_composite_value(batch_element)

  def aggregate(self, batch_element, batch_start=None, batch_end=None):
    """Stores `batch_element` after checking it against the batch size.

    Raises:
      ValueError: if a batch size is set and the element's 0th dimension
        exceeds it.
    """
    # TODO(psv): Add num_samples check here to detect when output batch
    # #samples is < batch size and != input batch #samples.
    expected_batch = self.batch_size
    if expected_batch and expected_batch < batch_element.shape[0]:
      expected_shape = (expected_batch,) + batch_element.shape[1:]
      raise ValueError(
          'Mismatch between expected batch size and model output batch size. '
          'Output shape = {}, expected output shape = shape {}'.format(
              batch_element.shape, expected_shape))
    self.results.append(batch_element)

  def finalize(self):
    """Joins all collected batches along axis 0 into `self.results`."""
    # Special case of single batch inference which skips a copy.
    if len(self.results) == 1:
      self.results = self.results[0]
      return

    if not self.composite:
      self.results = np.concatenate(self.results, axis=0)
      return

    # TODO(taylorrobie): efficiently concatenate.
    merged = self.results[0]
    for piece in self.results[1:]:
      merged = _append_composite_tensor(merged, piece)
    self.results = merged
# Number of worker threads in the shared copy pool.
_COPY_THREADS = 4
# Lazily created singleton; see `get_copy_pool`.
_COPY_POOL = None


def get_copy_pool():
  """Shared threadpool for copying arrays.

  Pool instantiation takes ~ 2ms, so a singleton pool is used rather than
  creating a pool per SliceAggregator. The pool is created on first use and
  its `close` method is registered to run at interpreter exit.

  Returns:
    The global copy threadpool.
  """
  global _COPY_POOL
  if _COPY_POOL is not None:
    return _COPY_POOL

  pool = multiprocessing.pool.ThreadPool(_COPY_THREADS)
  _COPY_POOL = pool
  atexit.register(pool.close)
  return pool
class SliceAggregator(Aggregator):
  """Combine arrays where the final size is known.

  This class expects to aggregate a single tensor-like rather than a nested
  structure of tensor-likes.

  NumPy copies are an operation that threads handle quite well because all of
  the heavy lifting is in c and does not need the GIL. Moreover, we can perform
  lock-free writes to the same buffer in multiple threads because the nature of
  result aggregation guarantees that either the indices are disjoint or the
  aggregator will throw an exception in finalize. Moreover, because aggregation
  is performed on the slowest varying dimension, assignments for a given batch
  will write to contiguous blocks of memory, further minimizing contention.

  There is, however, some scheduling and context switching overhead which will
  offset the gains from pipelining the slice assignment. Below a given threshold
  it is faster to simply assign in the main thread rather than enqueue the
  assignment in a side thread. The exact threshold will vary from system to
  system, but the time is not very sensitive to the exact transition so a value
  of 2 ** 14 was chosen which should be reasonable on most systems.
  """

  # Batches with fewer elements than this are copied in the calling thread.
  _BINARY_SIZE_THRESHOLD = 2 ** 14
  # Total time `finalize` waits for all pending copies before giving up.
  _MAX_COPY_SECONDS = 300

  def __init__(self, num_samples, batch_size):
    self._async_copies = []  # One `threading.Event` per enqueued copy.
    self._pool = get_copy_pool()
    self._errors = []  # Exceptions caught inside pool threads.
    super().__init__(
        use_steps=False,
        num_samples=num_samples,
        steps=None,
        batch_size=batch_size)

  def create(self, batch_element):
    """Allocates the output buffer from the first batch's shape and dtype."""
    # This step does not need to be pipelined because NumPy empty array
    # initialization is effectively instantaneous.
    trailing_dims = batch_element.shape[1:]
    self.results = np.empty(
        shape=(self.num_samples,) + trailing_dims, dtype=batch_element.dtype)

  def aggregate(self, batch_element, batch_start, batch_end):
    """Writes `batch_element` into `results[batch_start:batch_end]`.

    Small batches are copied inline; larger ones are handed to the copy pool.

    Raises:
      ValueError: if the batch covers all samples but its 0th dimension does
        not equal `num_samples`.
      Exception: the first error recorded by an earlier asynchronous copy.
    """
    # Fail early.
    if self._errors:
      raise self._errors[0]

    # In the special case of single batch inference, no copy is needed.
    covers_everything = (batch_end - batch_start) == self.num_samples
    if covers_everything:
      if self.num_samples != batch_element.shape[0]:
        raise ValueError(
            'Mismatch between expected batch size and model output batch size. '
            'Output shape = {}, expected output shape = shape {}'.format(
                batch_element.shape, self.results.shape))

      self.results = batch_element
      return

    # This is an approximate threshold, so we don't need to consider the number
    # of bytes per element.
    num_elements = np.prod(batch_element.shape)
    if num_elements >= self._BINARY_SIZE_THRESHOLD:
      copy_done = threading.Event()
      self._pool.apply_async(
          self._slice_assign,
          args=(batch_element, batch_start, batch_end, copy_done))
      self._async_copies.append(copy_done)
    else:
      self.results[batch_start:batch_end] = batch_element

  def _slice_assign(self, batch_element, batch_start, batch_end, is_finished):
    """Copies one batch into the results buffer and signals `is_finished`."""
    try:
      self.results[batch_start:batch_end] = batch_element

    except Exception as e:  # pylint: disable=broad-except
      # `_slice_assign` should only be called in threads and exceptions raised
      # in threads do not carry over to the main thread. So instead we perform
      # a broad catch in the thread and then store the exception to be
      # re-raised in the main thread.
      self._errors.append(e)

    finally:
      is_finished.set()

  def finalize(self):
    """Waits for pending copies, then re-raises the first recorded error.

    Raises:
      ValueError: if the copies do not all complete within
        `_MAX_COPY_SECONDS` seconds in total.
      Exception: the first error recorded by an asynchronous copy.
    """
    deadline = time.time() + self._MAX_COPY_SECONDS
    for copy_done in self._async_copies:
      remaining = max(0., deadline - time.time())
      if not copy_done.wait(remaining):
        raise ValueError('Timed out waiting for copy to complete.')

    if self._errors:
      raise self._errors[0]
class OutputsAggregator(Aggregator):
  """Aggregator that concatenates outputs.

  Each flattened output gets its own `ConcatAggregator` or `SliceAggregator`;
  `finalize` packs their results back into the structure of the first batch.
  """

  # Shallow structure of the first batch's outputs, captured in `create`.
  _structure = None

  def create(self, batch_outs):
    """Creates one per-output aggregator from the first batch's outputs.

    Raises:
      RuntimeError: if an output is neither an ndarray nor a composite
        tensor-like.
    """
    # SparseTensorValue is a named tuple which nest will flatten, so we need
    # to guard it to properly handle the structure.
    self._structure = nest.get_traverse_shallow_structure(
        lambda x: not is_composite_or_composite_value(x), batch_outs)
    flat_outs = nest.flatten_up_to(self._structure, batch_outs)

    for batch_element in flat_outs:
      if is_composite_or_composite_value(batch_element):
        # If the output is not a ndarray, it will be either a composite tensor
        # or a composite tensor's Value object. In either case, we can't
        # allocate an array to hold the object - we'll handle it later.
        aggregator = ConcatAggregator(self.batch_size)
      elif isinstance(batch_element, np.ndarray):
        if self.use_steps:
          aggregator = ConcatAggregator(self.batch_size)
        else:
          aggregator = SliceAggregator(self.num_samples, self.batch_size)
      else:
        # This is not a ndarray, a CompositeTensor, or a CompositeTensorValue.
        # Fail fast rather than trying to concatenate it.
        raise RuntimeError('Attempted to aggregate unsupported object {}.'
                           .format(batch_element))

      self.results.append(aggregator)
      aggregator.create(batch_element)

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Forwards each flattened output to its per-output aggregator."""
    flat_outs = nest.flatten_up_to(self._structure, batch_outs)
    for batch_element, aggregator in zip(flat_outs, self.results):
      aggregator.aggregate(batch_element, batch_start, batch_end)

  def finalize(self):
    """Finalizes every per-output aggregator and repacks their results."""
    for aggregator in self.results:
      aggregator.finalize()
    self.results = [aggregator.results for aggregator in self.results]
    self.results = nest.pack_sequence_as(self._structure, self.results)
def get_progbar(model, count_mode, include_metrics=True):
  """Builds a `ProgbarLogger` callback for `model`.

  Args:
    model: Object whose optional `metrics_names` attribute lists the loss
      followed by the metric names.
    count_mode: Passed through to `ProgbarLogger`.
    include_metrics: Whether to pass the model's metric names (minus the
      leading `loss` entry) as stateful metrics.

  Returns:
    A `cbks.ProgbarLogger` instance.
  """
  stateful_metric_names = None
  if include_metrics:
    metric_names = getattr(model, 'metrics_names', None)
    # An absent or empty list is passed through as-is; otherwise drop `loss`.
    stateful_metric_names = metric_names[1:] if metric_names else metric_names
  return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)
def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
  """Determine the number of samples provided for training and evaluation.

  The number of samples is not defined when running with `steps`,
  in which case the number of samples is set to `None`.

  Args:
    ins: List of tensors to be fed to the Keras function.
    batch_size: Integer batch size or `None` if not defined.
    steps: Total number of steps (batches of samples) before declaring
      `_predict_loop` finished. Ignored with the default value of `None`.
    steps_name: The public API's parameter name for `steps`.

  Raises:
    ValueError: when `steps` and `batch_size` are both not `None`, because
      they are mutually exclusive. `check_steps_argument` performs further
      validation of `ins`/`steps` and may raise as well.

  Returns:
    `None` when `check_steps_argument` reports that the loop is step-based, or
    when the first element of `ins` has no `shape` attribute. Otherwise the
    size of the first dimension of the first element of `ins`, as an int.
  """
  both_given = not (steps is None or batch_size is None)
  if both_given:
    raise ValueError('If ' + steps_name +
                     ' is set, the `batch_size` must be None.')
  if check_steps_argument(ins, steps, steps_name):
    return None

  first_input = ins[0]
  if not hasattr(first_input, 'shape'):
    return None  # Edge case where ins == [static_learning_phase]
  return int(first_input.shape[0])
def standardize_single_array(x, expected_shape=None):
  """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1.

  Args:
    x: Array-like or tensor to standardize. `None` and composite tensor-likes
      are returned unchanged.
    expected_shape: Optional expected shape; when it has exactly one
      dimension, rank-1 data is left as is.

  Returns:
    `x`, with a trailing axis added if it was rank 1 and expansion applies.

  Raises:
    ValueError: if `x` is an integer.
  """
  if x is None or is_composite_or_composite_value(x):
    return x

  if isinstance(x, int):
    raise ValueError(
        'Expected an array data type but received an integer: {}'.format(x))

  is_rank_one = x.shape is not None and len(x.shape) == 1
  wants_rank_one = expected_shape is not None and len(expected_shape) == 1
  if not is_rank_one or wants_rank_one:
    return x

  if tensor_util.is_tf_type(x):
    return array_ops.expand_dims(x, axis=1)
  return np.expand_dims(x, 1)
def get_composite_shape(tensor):
  """Returns the shape of the passed composite tensor.

  Args:
    tensor: A composite tensor or composite tensor value.

  Returns:
    `tensor.dense_shape` for a `SparseTensorValue`, `tensor.shape` otherwise.
  """
  # SparseTensorValues use a 'dense_shape' attribute
  is_sparse_value = isinstance(tensor, sparse_tensor.SparseTensorValue)
  return tensor.dense_shape if is_sparse_value else tensor.shape
def standardize_input_data(data,
                           names,
                           shapes=None,
                           check_batch_axis=True,
                           exception_prefix=''):
  """Normalizes inputs and targets provided by users.

  Users may pass data as a list of arrays, dictionary of arrays,
  or as a single array. We normalize this to an ordered list of
  arrays (same order as `names`), while checking that the provided
  arrays have shapes that match the network's expectations.

  Args:
    data: User-provided input data (polymorphic).
    names: List of expected array names.
    shapes: Optional list of expected array shapes.
    check_batch_axis: Boolean; whether to check that the batch axis of the
      arrays matches the expected value found in `shapes`.
    exception_prefix: String prefix used for exception formatting.

  Returns:
    List of standardized input arrays (one array per model input).

  Raises:
    ValueError: in case of improperly formatted user-provided data.
    TypeError: if a single non-array object is passed where an array is
      expected.
  """

  def unwrap_dataframe(x):
    # Detected by class name, as the rest of this function does, rather than
    # by an isinstance check against pandas.
    return x.values if x.__class__.__name__ == 'DataFrame' else x

  try:
    data_len = len(data)
  except TypeError:
    # For instance if data is `None` or a symbolic Tensor.
    data_len = None

  if not names:
    if data_len and not isinstance(data, dict):
      # Build a single message; passing `data` as a second exception argument
      # would make the error render as a tuple.
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': expected no data, but got: ' + str(data))
    return []
  if data is None:
    return [None for _ in range(len(names))]

  if isinstance(data, dict):
    try:
      data = [unwrap_dataframe(data[x]) for x in names]
    except KeyError as e:
      raise ValueError('No data provided for "' + str(e.args[0]) +
                       '". Need data for each key in: ' + str(names)) from e
  elif isinstance(data, (list, tuple)):
    if isinstance(data[0], (list, tuple)):
      data = [np.asarray(d) for d in data]
    elif len(names) == 1 and isinstance(data[0], (float, int)):
      data = [np.asarray(data)]
    else:
      data = [unwrap_dataframe(x) for x in data]
  else:
    data = [unwrap_dataframe(data)]

  if shapes is not None:
    data = [
        standardize_single_array(x, shape) for (x, shape) in zip(data, shapes)
    ]
  else:
    data = [standardize_single_array(x) for x in data]

  if len(data) != len(names):
    if data and hasattr(data[0], 'shape'):
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': the list of Numpy arrays that you are passing to '
                       'your model is not the size the model expected. '
                       'Expected to see ' + str(len(names)) + ' array(s), ' +
                       'for inputs ' + str(names) + ' but instead got the '
                       'following list of ' + str(len(data)) + ' arrays: ' +
                       str(data)[:200] + '...')
    elif len(names) > 1:
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': you are passing a list as input to your model, '
                       'but the model expects a list of ' + str(len(names)) +
                       ' Numpy arrays instead. The list you passed was: ' +
                       str(data)[:200])
    elif len(data) == 1 and not hasattr(data[0], 'shape'):
      raise TypeError('Error when checking model ' + exception_prefix +
                      ': data should be a Numpy array, or list/dict of '
                      'Numpy arrays. Found: ' + str(data)[:200] + '...')
    elif len(names) == 1:
      data = [np.asarray(data)]

  # Check shapes compatibility.
  if shapes:
    for i, name in enumerate(names):
      shape = shapes[i]
      if shape is None:
        continue

      if tensor_util.is_tf_type(data[i]):
        tensorshape = data[i].shape
        if not tensorshape:
          continue
        data_shape = tuple(tensorshape.as_list())
      elif is_composite_or_composite_value(data[i]):
        tensorshape = get_composite_shape(data[i])
        # Composite *value* objects report their shape as a plain sequence
        # (no `as_list`), so only call `as_list` when it exists.
        if hasattr(tensorshape, 'as_list'):
          data_shape = tuple(tensorshape.as_list())
        else:
          data_shape = tuple(tensorshape)
      else:
        data_shape = data[i].shape

      if len(data_shape) != len(shape):
        raise ValueError('Error when checking ' + exception_prefix +
                         ': expected ' + name + ' to have ' +
                         str(len(shape)) + ' dimensions, but got array '
                         'with shape ' + str(data_shape))
      if not check_batch_axis:
        data_shape = data_shape[1:]
        shape = shape[1:]
      for dim, ref_dim in zip(data_shape, shape):
        # `None` on either side means "unknown" and matches anything.
        if ref_dim != dim and ref_dim is not None and dim is not None:
          raise ValueError('Error when checking ' + exception_prefix +
                           ': expected ' + name + ' to have shape ' +
                           str(shape) + ' but got array with shape ' +
                           str(data_shape))
  return data
def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
  """Maps `sample_weight` or `class_weight` to model outputs.

  Args:
    x_weight: User-provided `sample_weight` or `class_weight` argument.
    output_names: List of output names (strings) in the model.
    weight_type: A string used purely for exception printing.

  Returns:
    A list of `sample_weight` or `class_weight` where there are exactly
    one element per model output.

  Raises:
    ValueError: In case of invalid user-provided argument.
    TypeError: If the model has several outputs and `x_weight` is neither a
      list/tuple nor a mapping.
  """
  if x_weight is None or (isinstance(x_weight, (list, tuple)) and
                          len(x_weight) == 0):  # pylint: disable=g-explicit-length-test
    return [None for _ in output_names]

  if len(output_names) == 1:
    if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
      return x_weight
    if isinstance(x_weight, dict) and output_names[0] in x_weight:
      return [x_weight[output_names[0]]]
    # Anything else (including a dict not keyed by the output name) is taken
    # to be the weight for the single output itself.
    return [x_weight]

  if isinstance(x_weight, (list, tuple)):
    if len(x_weight) != len(output_names):
      # The space before "array" was missing, gluing the words together.
      raise ValueError('Provided `' + weight_type + '` was a list of ' +
                       str(len(x_weight)) + ' elements, but the model has ' +
                       str(len(output_names)) + ' outputs. '
                       'You should provide one `' + weight_type + '` '
                       'array per model output.')
    return x_weight

  if isinstance(x_weight, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
    # Outputs without an entry in the mapping get `None`.
    return [x_weight.get(name) for name in output_names]

  raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
                  'should be either a list or a dict. '
                  'Provided `' + weight_type + '` type not understood: ' +
                  str(x_weight))
def standardize_class_weights(class_weight, output_names):
  """Maps `class_weight` to model outputs.

  Thin wrapper around `standardize_sample_or_class_weights` with
  `weight_type='class_weight'`.
  """
  return standardize_sample_or_class_weights(
      x_weight=class_weight,
      output_names=output_names,
      weight_type='class_weight')
def standardize_sample_weights(sample_weight, output_names):
  """Maps `sample_weight` to model outputs.

  Thin wrapper around `standardize_sample_or_class_weights` with
  `weight_type='sample_weight'`.
  """
  return standardize_sample_or_class_weights(
      x_weight=sample_weight,
      output_names=output_names,
      weight_type='sample_weight')
def check_array_lengths(inputs, targets, weights=None):
  """Does user input validation for numpy arrays.

  Only entries that are neither `None` nor tensors/composite tensors take part
  in the comparison.

  Args:
    inputs: list of Numpy arrays of inputs.
    targets: list of Numpy arrays of targets.
    weights: list of Numpy arrays of sample weights.

  Raises:
    ValueError: in case of incorrectly formatted data.
  """

  def is_tensor_or_composite_tensor(x):
    return tensor_util.is_tf_type(x) or is_composite_or_composite_value(x)

  def set_of_lengths(x):
    # Returns the set of distinct sample counts (size of the 0th dimension)
    # among the arrays in `x`; empty when `x` is None.
    if x is None:
      return set()
    return {
        y.shape[0]
        for y in x
        if y is not None and not is_tensor_or_composite_tensor(y)
    }

  def shapes_of(arrays):
    # Used only for error messages; tolerates entries without a `shape` (e.g.
    # None) so that building the message cannot itself fail.
    return str([getattr(a, 'shape', None) for a in arrays])

  set_x = set_of_lengths(inputs)
  set_y = set_of_lengths(targets)
  set_w = set_of_lengths(weights)
  if len(set_x) > 1:
    raise ValueError('All input arrays (x) should have '
                     'the same number of samples. Got array shapes: ' +
                     shapes_of(inputs))
  if len(set_y) > 1:
    raise ValueError('All target arrays (y) should have '
                     'the same number of samples. Got array shapes: ' +
                     shapes_of(targets))
  if set_x and set_y and next(iter(set_x)) != next(iter(set_y)):
    raise ValueError('Input arrays should have '
                     'the same number of samples as target arrays. '
                     'Found ' + str(next(iter(set_x))) + ' input samples '
                     'and ' + str(next(iter(set_y))) + ' target samples.')
  if len(set_w) > 1:
    raise ValueError('All sample_weight arrays should have '
                     'the same number of samples. Got array shapes: ' +
                     shapes_of(weights))
  if set_y and set_w and next(iter(set_y)) != next(iter(set_w)):
    # Label the counts by what they are: targets and sample weights.
    raise ValueError('Sample_weight arrays should have '
                     'the same number of samples as target arrays. Got ' +
                     str(next(iter(set_y))) + ' target samples and ' +
                     str(next(iter(set_w))) + ' sample_weight samples.')
def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
  """Does validation on the compatibility of targets and loss functions.

  This helps prevent users from using loss functions incorrectly. This check
  is purely for UX purposes.

  Args:
    targets: list of Numpy arrays of targets.
    loss_fns: list of loss functions.
    output_shapes: list of shapes of model outputs.

  Raises:
    ValueError: if a loss function or target array
      is incompatible with an output.
  """
  # Losses for which the target must have the same non-batch shape as the
  # output, given either as bare functions or as loss classes.
  shape_strict_loss_fns = {
      losses.mean_squared_error, losses.binary_crossentropy,
      losses.categorical_crossentropy
  }
  shape_strict_loss_classes = (losses.MeanSquaredError,
                               losses.BinaryCrossentropy,
                               losses.CategoricalCrossentropy)

  for y, loss, shape in zip(targets, loss_fns, output_shapes):
    # Nothing to validate without a concrete (non-tensor) target and a loss.
    if y is None or loss is None or tensor_util.is_tf_type(y):
      continue

    if losses.is_categorical_crossentropy(loss) and y.shape[-1] == 1:
      raise ValueError('You are passing a target array of shape ' +
                       str(y.shape) +
                       ' while using as loss `categorical_crossentropy`. '
                       '`categorical_crossentropy` expects '
                       'targets to be binary matrices (1s and 0s) '
                       'of shape (samples, classes). '
                       'If your targets are integer classes, '
                       'you can convert them to the expected format via:\n'
                       '```\n'
                       'from keras.utils import to_categorical\n'
                       'y_binary = to_categorical(y_int)\n'
                       '```\n'
                       '\n'
                       'Alternatively, you can use the loss function '
                       '`sparse_categorical_crossentropy` instead, '
                       'which does expect integer targets.')

    is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
    is_shape_strict = isinstance(loss, shape_strict_loss_classes) or (
        is_loss_wrapper and loss.fn in shape_strict_loss_fns)
    if not is_shape_strict:
      continue

    # Compare every non-batch dimension; unknown output dims match anything.
    has_mismatch = any(
        out_dim is not None and target_dim != out_dim
        for target_dim, out_dim in zip(y.shape[1:], shape[1:]))
    if not has_mismatch:
      continue

    loss_name = loss.name
    if loss_name is None:
      loss_type = loss.fn if is_loss_wrapper else type(loss)
      loss_name = loss_type.__name__
    raise ValueError('A target array with shape ' + str(y.shape) +
                     ' was passed for an output of shape ' + str(shape) +
                     ' while using as loss `' + loss_name + '`. '
                     'This loss expects targets to have the same shape '
                     'as the output.')
def collect_per_output_metric_info(metrics,
                                   output_names,
                                   output_shapes,
                                   loss_fns,
                                   from_serialized=False,
                                   is_weighted=False):
  """Builds, for every model output, an ordered mapping of metric name to metric.

  Args:
    metrics: a flat list, a list of lists (one inner list per output) or a
      mapping keyed by output name, holding the metrics to attach.
    output_names: a list of the names (strings) of model outputs.
    output_shapes: a list of the shapes of model outputs, indexed like
      `output_names`.
    loss_fns: a list of the loss functions, indexed like `output_names`.
    from_serialized: whether the model the metrics are being sourced from is
      being initialized from a serialized format.
    is_weighted: Boolean indicating whether the given metrics are weighted.

  Returns:
    A list with one entry per model output. When `metrics` is empty every
    entry is an empty dict; otherwise every entry is an `OrderedDict` mapping
    the metric name to a metric object, e.g. for two outputs:
    `[{'acc': binary_accuracy(), 'ce': binary_crossentropy()},
      {'acc': binary_accuracy()}]`

  Raises:
    TypeError: if `metrics` is neither a list nor a mapping.
    ValueError: if `metrics` is a list of lists whose length differs from the
      number of outputs.
  """
  if not metrics:
    return [{} for _ in output_names]

  if isinstance(metrics, list):
    if any(isinstance(entry, list) for entry in metrics):
      if len(metrics) != len(output_names):
        raise ValueError('When passing a list of lists as `metrics`, '
                         'it should have one entry per model output. '
                         'The model has ' + str(len(output_names)) +
                         ' outputs, but you passed metrics=' + str(metrics))
      # One (possibly scalar) entry per output was supplied.
      grouped = [generic_utils.to_list(entry) for entry in metrics]
    elif len(output_names) > 1:
      # A flat list applies to every output; each output gets its own clones.
      grouped = [[metrics_module.clone_metric(m) for m in metrics]
                 for _ in output_names]
    else:
      grouped = [metrics]
  elif isinstance(metrics, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
    grouped = [
        generic_utils.to_list(metrics.get(output_name, []))
        for output_name in output_names
    ]
  else:
    raise TypeError('Type of `metrics` argument not understood. '
                    'Expected a list or dictionary, found: ' + str(metrics))

  per_output_metrics = []
  for index, output_metrics in enumerate(grouped):
    named_metrics = collections.OrderedDict()
    for metric in output_metrics:
      name = get_metric_name(metric, is_weighted)
      fn = get_metric_function(
          metric, output_shape=output_shapes[index], loss_fn=loss_fns[index])
      fn._from_serialized = from_serialized  # pylint: disable=protected-access

      # Stateless metric functions get wrapped into a stateful metric object.
      if not isinstance(fn, metrics_module.Metric):
        fn = metrics_module.MeanMetricWrapper(fn, name=name)
        # A metric revived from something stateless, such as a string
        # (e.g. "accuracy"), may later need transformations such as renaming
        # reapplied, so the wrapper is flagged as not serialized.
        fn._from_serialized = False  # pylint: disable=protected-access
      named_metrics[name] = fn
    per_output_metrics.append(named_metrics)

  return per_output_metrics
def batch_shuffle(index_array, batch_size):
  """Shuffles the order of whole batches while keeping each batch contiguous.

  Useful for shuffling HDF5 arrays
  (where one cannot access arbitrary indices).

  Args:
    index_array: array of indices to be shuffled.
    batch_size: integer.

  Returns:
    The `index_array` array, shuffled in a batch-wise fashion. Items that do
    not fill a complete batch stay at the end, in their original order.
  """
  num_full_batches = int(len(index_array) / batch_size)
  cutoff = num_full_batches * batch_size
  # Reshaping needs a length divisible by `batch_size`, so the leftover items
  # are set aside and appended back once the full batches are shuffled.
  leftover = index_array[cutoff:]
  batches = index_array[:cutoff].reshape((num_full_batches, batch_size))
  np.random.shuffle(batches)
  return np.append(batches.flatten(), leftover)
def standardize_weights(y,
                        sample_weight=None,
                        class_weight=None,
                        sample_weight_mode=None):
  """Performs sample weight validation and standardization.

  Everything gets normalized to a single sample-wise (or timestep-wise)
  weight array. If both `sample_weight` and `class_weight` are provided,
  the weights are multiplied.

  Args:
    y: Numpy array or Tensor of model targets to be weighted.
    sample_weight: User-provided `sample_weight` argument. A tuple is replaced
      by its first element.
    class_weight: User-provided `class_weight` argument. It is only applied
      when it is a `dict` (class -> weight).
    sample_weight_mode: `None`, `"samplewise"` or `"temporal"`. `"temporal"`
      indicates that we expect 2D weight data that will be applied to the
      last 2 dimensions of the targets (i.e. we are weighting timesteps, not
      samples).

  Returns:
    The target weights: `class_weight`-derived weights multiplied by
    `sample_weight` when both are given, whichever one is given otherwise,
    and `None` when neither is given.

  Raises:
    ValueError: In case of invalid user-provided arguments.
  """
  # Iterator may return sample_weight as 1-tuple
  if isinstance(sample_weight, tuple):
    sample_weight = sample_weight[0]
  if sample_weight_mode is not None and sample_weight_mode != 'samplewise':
    if sample_weight_mode != 'temporal':
      raise ValueError('"sample_weight_mode" '
                       'should be None or "temporal". '
                       'Found: ' + str(sample_weight_mode))
    if len(y.shape) < 3:
      raise ValueError('Found a sample_weight array for '
                       'an input with shape ' + str(y.shape) + '. '
                       'Timestep-wise sample weighting (use of '
                       'sample_weight_mode="temporal") is restricted to '
                       'outputs that are at least 3D, i.e. that have '
                       'a time dimension.')
    if sample_weight is not None and len(sample_weight.shape) != 2:
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + '. '
                       'In order to use timestep-wise sample weighting, '
                       'you should pass a 2D sample_weight array.')
  else:
    if sample_weight is not None and len(sample_weight.shape) != 1:
      raise ValueError(
          'Found a sample_weight array with shape {}. In order to '
          'use timestep-wise sample weights, you should specify '
          'sample_weight_mode="temporal" in compile(); found "{}" '
          'instead. If you just mean to use sample-wise weights, '
          'make sure your sample_weight array is 1D.'.format(
              sample_weight.shape, sample_weight_mode))

  if sample_weight is not None:
    if len(sample_weight.shape) > len(y.shape):
      raise ValueError('Found a sample_weight with shape ' +
                       str(sample_weight.shape) + '. '
                       'Expected sample_weight with rank '
                       'less than or equal to ' + str(len(y.shape)))

    # The leading dimensions of `y` must match the weights exactly; this is
    # only checked for non-tensor weights.
    if (not tensor_util.is_tf_type(sample_weight) and
        y.shape[:sample_weight.ndim] != sample_weight.shape):
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + ' for an input with shape ' +
                       str(y.shape) + '. '
                       'sample_weight cannot be broadcast.')

  # Class weights applied per-sample.
  class_sample_weight = None
  if isinstance(class_weight, dict):
    if len(y.shape) > 2:
      raise ValueError('`class_weight` not supported for '
                       '3+ dimensional targets.')

    if tensor_util.is_tf_type(y):
      # Few classes are expected, so densifying is reasonable.
      keys = np.array(sorted(class_weight.keys()))
      values = np.array([class_weight[i] for i in keys])
      # Classes without a weight stay NaN so they can be detected below.
      weight_vector = np.zeros(np.max(keys) + 1)
      weight_vector[:] = np.nan
      weight_vector[keys] = values

      # 2D targets with more than one column are reduced with argmax;
      # anything else is flattened and cast to integer class ids.
      y_classes = smart_cond.smart_cond(
          len(y.shape.as_list()) == 2 and backend.shape(y)[1] > 1,
          lambda: backend.argmax(y, axis=1),
          lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))
      class_sample_weight = array_ops.gather(weight_vector, y_classes)
      # NOTE(review): the tensor returned by `check_numerics` is discarded
      # here; confirm the check is actually executed in graph mode.
      gen_array_ops.check_numerics(
          class_sample_weight,
          'Invalid classes or class weights detected. NaN values indicate that '
          'an appropriate class weight could not be determined.')
      class_sample_weight = math_ops.cast(class_sample_weight, backend.floatx())
      if sample_weight is not None:
        sample_weight = math_ops.cast(
            tensor_conversion.convert_to_tensor_v2_with_dispatch(sample_weight),
            backend.floatx(),
        )
    else:
      y_classes = y
      if len(y.shape) == 2:
        if y.shape[1] > 1:
          y_classes = np.argmax(y, axis=1)
        elif y.shape[1] == 1:
          y_classes = np.reshape(y, y.shape[0])

      # Classes absent from `class_weight` are dropped here, so a length
      # mismatch below reveals that some class has no weight.
      class_sample_weight = np.asarray(
          [class_weight[cls] for cls in y_classes if cls in class_weight])

      if len(class_sample_weight) != len(y_classes):
        # subtract the sets to pick all missing classes
        existing_classes = set(y_classes)
        existing_class_weight = set(class_weight.keys())
        raise ValueError(
            '`class_weight` must contain all classes in the data.'
            ' The classes %s exist in the data but not in '
            '`class_weight`.' % (existing_classes - existing_class_weight))

  if class_sample_weight is not None and sample_weight is not None:
    # Multiply weights if both are provided.
    return class_sample_weight * sample_weight
  if sample_weight is not None:
    return sample_weight
  if class_sample_weight is not None:
    return class_sample_weight
  return None
def has_symbolic_tensors(ls):
  """Returns `has_tensors(ls)` outside eager execution, and False when eager."""
  return not context.executing_eagerly() and has_tensors(ls)
def has_tensors(ls):
  """Returns true if `ls` contains tensors.

  `ls` may be a list/tuple, a dict (its values are inspected) or a single
  value. Ragged tensors are deliberately not counted.
  """

  # Note: at some point in time ragged tensors didn't count as tensors, so this
  # returned false for ragged tensors. Making this return true fails some tests
  # which would then require a steps_per_epoch argument.
  def _is_non_ragged_tensor(value):
    return tensor_util.is_tf_type(value) and not isinstance(
        value, ragged_tensor.RaggedTensor)

  if isinstance(ls, (list, tuple)):
    return any(_is_non_ragged_tensor(item) for item in ls)
  if isinstance(ls, dict):
    return any(_is_non_ragged_tensor(item) for item in ls.values())
  return _is_non_ragged_tensor(ls)
def get_metric_name(metric, weighted=False):
  """Returns the name corresponding to the given metric input.

  Args:
    metric: Metric function name or reference.
    weighted: Boolean indicating if the given metric is weighted. Only used
      when TF2 behavior is disabled, where it adds a `weighted_` prefix.

  Returns:
    The metric name.
  """

  def _resolved_name(identifier):
    # Prefer the `name` attribute of the resolved metric when it has one.
    resolved = metrics_module.get(identifier)
    if hasattr(resolved, 'name'):
      return resolved.name
    return resolved.__name__

  if tf2.enabled():
    # We keep the string that the user has set in compile as the metric name.
    if isinstance(metric, str):
      return metric
    return _resolved_name(metric)

  if metric in ('accuracy', 'acc'):
    suffix = 'acc'
  elif metric in ('crossentropy', 'ce'):
    suffix = 'ce'
  else:
    suffix = _resolved_name(metric)
  prefix = 'weighted_' if weighted else ''
  return prefix + suffix
def get_metric_function(metric, output_shape=None, loss_fn=None):
  """Returns the metric function corresponding to the given metric input.

  The shorthands 'accuracy'/'acc' and 'crossentropy'/'ce' are resolved to a
  binary, sparse categorical or categorical variant based on the last
  dimension of `output_shape` and on `loss_fn`; any other input is passed to
  `metrics_module.get`.

  Args:
    metric: Metric function name or reference.
    output_shape: The shape of the output that this metric will be calculated
      for.
    loss_fn: The loss function used.

  Returns:
    The metric function.
  """
  if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
    return metrics_module.get(metric)

  def _loss_is(loss_class, loss_function):
    # True for an instance of `loss_class` or a wrapper around `loss_function`.
    if isinstance(loss_fn, loss_class):
      return True
    return (isinstance(loss_fn, losses.LossFunctionWrapper) and
            loss_fn.fn == loss_function)

  uses_sparse_categorical = _loss_is(losses.SparseCategoricalCrossentropy,
                                     losses.sparse_categorical_crossentropy)
  uses_binary = _loss_is(losses.BinaryCrossentropy, losses.binary_crossentropy)

  if metric in ['accuracy', 'acc']:
    binary_fn = metrics_module.binary_accuracy
    sparse_fn = metrics_module.sparse_categorical_accuracy
    categorical_fn = metrics_module.categorical_accuracy
  else:
    binary_fn = metrics_module.binary_crossentropy
    sparse_fn = metrics_module.sparse_categorical_crossentropy
    categorical_fn = metrics_module.categorical_crossentropy

  if output_shape[-1] == 1 or uses_binary:
    return binary_fn
  if uses_sparse_categorical:
    return sparse_fn
  # If the output_shape[-1] is not 1, then we know output is `categorical`.
  # We assume it is sparse categorical only if loss is explicitly given
  # as sparse categorical crossentropy loss.
  return categorical_fn
def call_metric_function(metric_fn,
                         y_true,
                         y_pred=None,
                         weights=None,
                         mask=None):
  """Invokes metric function and returns the metric result tensor.

  When a `mask` is given it is cast to the dtype of `y_pred` and either used
  as the sample weight (no `weights`) or multiplied into `weights`.
  """
  if mask is not None:
    mask = math_ops.cast(mask, y_pred.dtype)
    if weights is not None:
      # Update dimensions of weights to match with mask.
      weights = math_ops.cast(weights, dtype=y_pred.dtype)
      mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
          mask, sample_weight=weights)
      weights *= mask
    else:
      # Use mask as sample weight.
      weights = mask

  if y_pred is None:
    # `Mean` metric only takes a single value.
    return metric_fn(y_true, sample_weight=weights)
  return metric_fn(y_true, y_pred, sample_weight=weights)
def get_loss_function(loss):
  """Returns the loss corresponding to the loss input in `compile` API.

  Args:
    loss: `None`, a `losses.Loss` instance, a serialized loss configuration
      (mapping), a custom callable, or anything `losses.get` understands.

  Returns:
    `loss` itself when it is `None`, a `losses.Loss` instance or a callable
    without a `__name__` attribute; otherwise a `losses.LossFunctionWrapper`
    around the function resolved by `losses.get`.

  Raises:
    ValueError: if `loss` is a `losses.Loss` subclass rather than an instance.
  """
  if loss is None or isinstance(loss, losses.Loss):
    return loss

  if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
    # It is not safe to assume that the loss takes no constructor arguments.
    raise ValueError(
        'Received uninstantiated Loss class: {}\nPlease call loss classes '
        'before passing them to Model.compile.'.format(loss))

  # Deserialize loss configuration, if needed.
  if isinstance(loss, collections.abc.Mapping):
    loss = losses.get(loss)

  # Custom callable class.
  if callable(loss) and not hasattr(loss, '__name__'):
    return loss

  # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
  # in `LossFunctionWrapper` class.
  loss_fn = losses.get(loss)

  # For losses which are given as strings/functions in the compile API,
  # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
  # (both in distribution strategy context and otherwise).
  return losses.LossFunctionWrapper(
      loss_fn,
      name=loss_fn.__name__,
      reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)
def validate_dataset_input(x, y, sample_weight, validation_split=None):
  """Validates user input arguments when a dataset iterator is passed.

  Args:
    x: Input data. A `tf.data` dataset or iterator. Only used to build the
      error messages.
    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
      Expected to be `None` when `x` is a dataset iterator.
    sample_weight: An optional sample-weight array passed by the user to weight
      the importance of each sample in `x`. Expected to be `None` when `x` is a
      dataset iterator
    validation_split: Float between 0 and 1. Fraction of the training data to be
      used as validation data. Expected to be `None` (or `0.0`) when `x` is a
      dataset iterator.

  Raises:
    ValueError: if argument `y` or `sample_weight` or `validation_split` are
      provided by user.
  """
  if y is not None:
    raise ValueError('You passed a dataset or dataset iterator (%s) as '
                     'input `x` to your model. In that case, you should '
                     'not specify a target (`y`) argument, since the dataset '
                     'or dataset iterator generates both input data and '
                     'target data. '
                     'Received: %s' % (x, y))
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when input '
                     '`x` is a dataset or a dataset iterator. Instead, you '
                     'can provide sample_weight as the third element of your '
                     'dataset, i.e. (inputs, targets, sample_weight). '
                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
  # An explicit 0.0 is tolerated because it requests no validation data.
  if validation_split is not None and validation_split != 0.0:
    raise ValueError(
        '`validation_split` argument is not supported when '
        'input `x` is a dataset or a dataset iterator. '
        'Received: x=%s, validation_split=%f' % (x, validation_split))
def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'):
  """Helper function to validate either inputs or targets.

  Accepts a Numpy array or tensor, a list/tuple made only of those, or a dict
  when `allow_dict` is true; anything else raises `ValueError`. `orig_inp` and
  `field_name` are only used in the error messages.
  """

  def _is_array_or_tensor(value):
    return isinstance(value, np.ndarray) or tensor_util.is_tf_type(value)

  if isinstance(inp, (list, tuple)):
    if all(_is_array_or_tensor(item) for item in inp):
      return
    raise ValueError(
        'Please provide as model inputs either a single array or a list of '
        'arrays. You passed: {}={}'.format(field_name, str(orig_inp)))

  if isinstance(inp, dict):
    if allow_dict:
      return
    raise ValueError(
        'You cannot pass a dictionary as model {}.'.format(field_name))

  if not _is_array_or_tensor(inp):
    raise ValueError(
        'Please provide as model inputs either a single array or a list of '
        'arrays. You passed: {}={}'.format(field_name, orig_inp))
def check_generator_arguments(y=None, sample_weight=None,
                              validation_split=None):
  """Validates arguments passed when using a generator.

  Args:
    y: Must be `None`; targets come from the generator.
    sample_weight: Must be `None`; sample weights come from the generator.
    validation_split: Must be falsy (`None` or 0).

  Raises:
    ValueError: if `y` or `sample_weight` is given, or `validation_split` is
      truthy.
  """
  if y is not None:
    raise ValueError('`y` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass targets'
                     ' as the second element of the generator.')
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass sample'
                     ' weights as the third element of the generator.')
  if validation_split:
    raise ValueError('If your data is in the form of a Python generator, '
                     'you cannot use `validation_split`.')
def check_steps_argument(input_data, steps, steps_name):
  """Validates `steps` argument based on input data's type.

  `steps` must be provided when
  1. input data passed is an iterator.
  2. input data is `None` or an empty list (model built on top of symbolic
     tensors, no input data required).
  3. input data passed is a symbolic tensor.

  Args:
    input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
      tf.data.Dataset iterator or `None`.
    steps: Integer or `None`. Total number of steps (batches of samples) to
      execute.
    steps_name: The public API's parameter name for `steps`.

  Returns:
    boolean, True if `steps` argument is required (or the input is a dataset),
    else False.

  Raises:
    ValueError: if `steps` argument is required for given input data type
      but not provided.
  """
  is_x_iterator = isinstance(
      input_data, (iterator_ops.Iterator, iterator_ops.IteratorBase))
  is_empty_list = isinstance(input_data, list) and not input_data
  steps_required = (
      input_data is None or is_x_iterator or
      has_symbolic_tensors(input_data) or is_empty_list)

  if steps_required:
    if steps is not None:
      return True
    input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors'
    raise ValueError('When using {input_type} as input to a model, you should'
                     ' specify the `{steps_name}` argument.'.format(
                         input_type=input_type_str, steps_name=steps_name))

  if isinstance(input_data, (data_types.DatasetV1, data_types.DatasetV2)):
    return True

  if steps is None:
    return False

  # `steps` was given together with array-like data: warn, but do not fail.
  list_types = (np.ndarray, list, tuple)
  is_array_like = isinstance(input_data, list_types) or (
      isinstance(input_data, dict) and
      any(isinstance(value, list_types) for value in input_data.values()))
  if is_array_like:
    logging.warning('When passing input data as arrays, do not specify '
                    '`steps_per_epoch`/`steps` argument. '
                    'Please use `batch_size` instead.')
  return False
def cast_single_tensor(x, dtype=None):
  """Casts `x` to `dtype` (default `backend.floatx()`) if it is floating point.

  Numpy arrays are first converted to tensors; non-floating inputs are
  returned without a cast.
  """
  if isinstance(x, np.ndarray):
    x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x)
  target_dtype = dtype or backend.floatx()
  if not x.dtype.is_floating:
    return x
  return math_ops.cast(x, dtype=target_dtype)
def cast_if_floating_dtype_and_mismatch(targets, outputs):
  """Returns target data tensors using correct datatype.

  Each target whose dtype differs from its paired output is passed to
  `cast_single_tensor` with the output's dtype.

  Args:
    targets: tensor or list of targets.
    outputs: tensor or list of outputs.

  Returns:
    Targets in appropriate datatype: a single tensor when `targets` is a
    tensor, a list otherwise.
  """
  if tensor_util.is_tf_type(targets):
    # There is one target, so output[0] should be the only output.
    return cast_single_tensor(targets, dtype=outputs[0].dtype)

  def _match_dtype(target, output):
    if isinstance(target, np.ndarray):
      target = tensor_conversion.convert_to_tensor_v2_with_dispatch(target)
    if target.dtype == output.dtype:
      return target
    return cast_single_tensor(target, dtype=output.dtype)

  return [_match_dtype(target, out) for target, out in zip(targets, outputs)]
def cast_if_floating_dtype(x, dtype=None):
  """Applies `cast_single_tensor` to every element of the structure `x`.

  Casts only if the input is already a floating point type.

  Args:
    x: tensor or list/tuple of tensors.
    dtype: The dtype to which Tensors should be cast.

  Returns:
    Converted input.
  """

  def _cast(tensor):
    return cast_single_tensor(tensor, dtype=dtype)

  return nest.map_structure(_cast, x)
def cast_to_model_input_dtypes(x, model):
  """Casts the given data tensors to the dtypes of the model inputs.

  Args:
    x: tensor or list/tuple of tensors.
    model: The model.

  Returns:
    Converted input. Each tensor is casted to the dtype of the corresponding
    input in `model.inputs`.
  """

  def _dtype_of(tensor):
    return tensor.dtype

  target_dtypes = nest.map_structure(_dtype_of, model.inputs)
  return nest.map_structure(math_ops.cast, x, target_dtypes)
def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
  """Prepares sample weight modes for the model.

  Sets `sample_weight_mode` on every endpoint whose
  `should_skip_target_weights()` is false; skipped endpoints are left
  untouched.

  Args:
    training_endpoints: List of model _TrainingEndpoints.
    sample_weight_mode: sample weight mode user input passed from compile API.
      A mapping keyed by output name, a list/tuple with one entry per
      endpoint, or a single value applied to all endpoints.

  Raises:
    ValueError: In case of invalid `sample_weight_mode` input.
  """

  if isinstance(sample_weight_mode, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(
        'sample_weight_mode', sample_weight_mode,
        [e.output_name for e in training_endpoints])

    for end_point in training_endpoints:
      if not end_point.should_skip_target_weights():
        # Every endpoint that takes weights needs an explicit entry.
        if end_point.output_name not in sample_weight_mode:
          raise ValueError('Output ' + end_point.output_name +
                           ' missing from `_sample_weight_modes` dictionary')
        else:
          end_point.sample_weight_mode = sample_weight_mode.get(
              end_point.output_name)
  elif isinstance(sample_weight_mode, (list, tuple)):
    if len(sample_weight_mode) != len(training_endpoints):
      raise ValueError('When passing a list as sample_weight_mode, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(training_endpoints)) +
                       ' outputs, but you passed ' +
                       str(len(sample_weight_mode)) + ' _sample_weight_modes.')
    for mode, endpoint in zip(sample_weight_mode, training_endpoints):
      if not endpoint.should_skip_target_weights():
        endpoint.sample_weight_mode = mode
  else:
    for endpoint in training_endpoints:
      if not endpoint.should_skip_target_weights():
        endpoint.sample_weight_mode = sample_weight_mode
def prepare_loss_functions(loss, output_names):
  """Converts loss to a list of loss functions.

  Args:
    loss: String (name of objective function), objective function or
      `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple
      outputs, you can use a different loss on each output by passing a
      dictionary or a list of losses. The loss value that will be minimized by
      the model will then be the sum of all individual losses.
    output_names: List of model output names.

  Returns:
    A list of loss objective functions.

  Raises:
    ValueError: If loss is a dict with keys not in model output names,
      or if loss is a list with len not equal to model outputs.
  """
  if isinstance(loss, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys('loss', loss, output_names)
    per_output = []
    for output_name in output_names:
      if output_name not in loss:
        logging.warning(
            'Output {0} missing from loss dictionary. We assume '
            'this was done on purpose. The fit and evaluate APIs will not be '
            'expecting any data to be passed to {0}.'.format(output_name))
      per_output.append(get_loss_function(loss.get(output_name, None)))
    return per_output

  # A string is itself a sequence, so it has to be handled before the
  # generic sequence case.
  if isinstance(loss, str):
    return [get_loss_function(loss) for _ in output_names]

  if isinstance(loss, collections.abc.Sequence):
    if len(loss) != len(output_names):
      raise ValueError('When passing a list as loss, it should have one entry '
                       'per model outputs. The model has {} outputs, but you '
                       'passed loss={}'.format(len(output_names), loss))
    return nest.map_structure(get_loss_function, loss)

  # Any other single loss is resolved once per output.
  return [get_loss_function(loss) for _ in range(len(output_names))]
def prepare_loss_weights(training_endpoints, loss_weights=None):
  """Converts loss weights to a list of loss weights.

  The result loss weights will be populated on the training endpoint
  (attribute `loss_weight`).

  Args:
    training_endpoints: List of model training endpoints.
    loss_weights: Optional list/tuple or dictionary specifying scalar
      coefficients (Python floats) to weight the loss contributions of
      different model outputs. The loss value that will be minimized by the
      model will then be the *weighted sum* of all individual losses, weighted
      by the `loss_weights` coefficients. If a list or tuple, it is expected to
      have a 1:1 mapping to the model's outputs. If a dict, it is expected to
      map output names (strings) to scalar coefficients; outputs missing from
      the dict get a weight of 1. `None` gives every output a weight of 1.

  Raises:
    ValueError: If loss weight is a dict with key not in model output names,
      or if loss is a list with len not equal to model outputs.
    TypeError: If `loss_weights` is neither `None`, a mapping, a list nor a
      tuple.
  """
  if loss_weights is None:
    for e in training_endpoints:
      e.loss_weight = 1.
  elif isinstance(loss_weights, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(
        'loss_weights', loss_weights,
        [e.output_name for e in training_endpoints])
    for e in training_endpoints:
      e.loss_weight = loss_weights.get(e.output_name, 1.)
  elif isinstance(loss_weights, (list, tuple)):
    # Tuples are accepted as well, matching `prepare_sample_weight_modes`.
    if len(loss_weights) != len(training_endpoints):
      raise ValueError('When passing a list as loss_weights, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(training_endpoints)) +
                       ' outputs, but you passed loss_weights=' +
                       str(loss_weights))
    for w, e in zip(loss_weights, training_endpoints):
      e.loss_weight = w
  else:
    raise TypeError('Could not interpret loss_weights argument: ' +
                    str(loss_weights) + ' - expected a list or a dict.')
# TODO(rohanj): This is a hack to avoid depending on feature_column, which
# would create a cyclical dependency. Figure out a cleaner solution.
def is_feature_layer(layer):
  """Returns whether `layer` is a FeatureLayer or not.

  The decision is based solely on the `_is_feature_layer` attribute; objects
  without that attribute yield False.
  """
  try:
    return layer._is_feature_layer  # pylint: disable=protected-access
  except AttributeError:
    return False
def is_eager_dataset_or_iterator(data):
  """Returns True if executing eagerly and `data` is a dataset or iterator."""
  if not context.executing_eagerly():
    return False
  eager_input_types = (data_types.DatasetV1, data_types.DatasetV2,
                       iterator_ops.IteratorBase)
  return isinstance(data, eager_input_types)
# pylint: disable=protected-access
def get_dataset_graph_def(dataset):
  """Returns the `GraphDef` parsed from the dataset's serialized graph."""
  running_eagerly = context.executing_eagerly()
  serialized_graph = dataset._as_serialized_graph()
  # The serialized bytes are read with `.numpy()` when eager and through the
  # Keras backend otherwise.
  if running_eagerly:
    graph_def_str = serialized_graph.numpy()
  else:
    graph_def_str = backend.get_value(serialized_graph)
  return graph_pb2.GraphDef().FromString(graph_def_str)
def verify_dataset_shuffled(x):
  """Verifies that the dataset is shuffled.

  Looks for an op whose name starts with `ShuffleDataset` in the dataset's
  graph and in its function library, and logs a warning when none is found.

  Args:
    x: Dataset passed as an input to the model.

  Returns:
    boolean, whether the input dataset is shuffled or not.
  """
  assert isinstance(x, data_types.DatasetV2)
  graph_def = get_dataset_graph_def(x)

  def _contains_shuffle(nodes):
    return any(node.op.startswith('ShuffleDataset') for node in nodes)

  if _contains_shuffle(graph_def.node):
    return True
  # Also check graph_def.library.function for ds.interleave or ds.flat_map
  if any(_contains_shuffle(function.node_def)
         for function in graph_def.library.function):
    return True
  logging.warning('Expected a shuffled dataset but input dataset `x` is '
                  'not shuffled. Please invoke `shuffle()` on input dataset.')
  return False
def is_dataset_or_iterator(data):
  """Returns True if `data` is a dataset (V1/V2) or a dataset iterator."""
  dataset_or_iterator_types = (
      data_types.DatasetV1,
      data_types.DatasetV2,
      iterator_ops.Iterator,
      iterator_ops.IteratorBase,
  )
  return isinstance(data, dataset_or_iterator_types)
def get_iterator(dataset):
  """Create and initialize an iterator from a dataset.

  A one-shot iterator is built when executing eagerly, an initializable one
  otherwise.
  """
  make_iterator = (
      dataset_ops.make_one_shot_iterator
      if context.executing_eagerly() else
      dataset_ops.make_initializable_iterator)
  new_iterator = make_iterator(dataset)
  initialize_iterator(new_iterator)
  return new_iterator
def initialize_iterator(iterator):
  """Runs the iterator's initializer op in the backend session (graph mode).

  Does nothing when executing eagerly.
  """
  if context.executing_eagerly():
    return
  init_op = iterator.initializer
  backend.get_session((init_op,)).run(init_op)
def extract_tensors_from_dataset(dataset):
  """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset.

  Args:
    dataset: Dataset instance.

  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
  """
  x, y, weights = unpack_iterator_input(get_iterator(dataset))
  return x, y, weights
def unpack_iterator_input(iterator):
  """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.

  Args:
    iterator: Instance of a dataset iterator.

  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None:
    a 2-element list/tuple yields `(x, y, None)` and any element that is not
    a list/tuple yields `(element, None, None)`.

  Raises:
    RuntimeError: if the iterator is exhausted.
    ValueError: if the next element is a list/tuple whose length is not 2 or 3.
  """
  try:
    next_element = iterator.get_next()
  except errors.OutOfRangeError:
    raise RuntimeError('Your dataset iterator ran out of data; '
                       'Make sure that your dataset can generate '
                       'required number of samples.')

  if isinstance(next_element, (list, tuple)):
    if len(next_element) not in [2, 3]:
      # Wrap in a 1-tuple: a bare tuple operand would be unpacked by `%` and
      # raise TypeError instead of the intended ValueError.
      raise ValueError(
          'Please provide model inputs as a list or tuple of 2 or 3 '
          'elements: (input, target) or (input, target, sample_weights) '
          'Received %s' % (next_element,))
    if len(next_element) == 2:
      x, y = next_element
      weights = None
    else:
      x, y, weights = next_element
  else:
    x = next_element
    y = None
    weights = None
  return x, y, weights
def infer_steps_for_dataset(model,
                            dataset,
                            steps,
                            epochs=1,
                            steps_name='steps'):
  """Infers steps_per_epoch needed to loop through a dataset.

  Args:
    model: Keras model instance.
    dataset: Input data of type tf.data.Dataset.
    steps: Number of steps to draw from the dataset (may be None if unknown).
    epochs: Number of times to iterate over the dataset.
    steps_name: The string name of the steps argument, either `steps`,
      `validation_steps`, or `steps_per_epoch`. Only used for error message
      formatting.

  Returns:
    Integer or `None`. Inferred number of steps to loop through the dataset.
    `None` is returned if 1) the size of the dataset is unknown and `steps` was
    not specified, or 2) this is multi-worker training and auto sharding is
    enabled.

  Raises:
    ValueError: In case of invalid argument values.
  """
  assert isinstance(dataset, data_types.DatasetV2)
  if model._in_multi_worker_mode():
    shard_policy = dataset.options().experimental_distribute.auto_shard_policy
    if shard_policy != options_lib.AutoShardPolicy.OFF:
      # If the dataset would be auto-sharded, we should not infer a local
      # steps_per_epoch due to the possible imbalanced sharding between
      # workers.
      return None

  size = backend.get_value(cardinality.cardinality(dataset))

  if steps is None:
    if size == cardinality.INFINITE:
      raise ValueError('When passing an infinitely repeating dataset, you '
                       'must specify the `%s` argument.' % (steps_name,))
    # A non-negative cardinality is the step count; otherwise it is unknown.
    return size if size >= 0 else None

  if size >= 0 and steps * epochs > size:
    if epochs > 1:
      raise ValueError('The dataset you passed contains %s batches, but you '
                       'passed `epochs=%s` and `%s=%s`, which is a total of '
                       '%s steps. We cannot draw that many steps from this '
                       'dataset. We suggest to set `%s=%s`.' %
                       (size, epochs, steps_name, steps, steps * epochs,
                        steps_name, size // epochs))
    raise ValueError('The dataset you passed contains %s batches, but you '
                     'passed `%s=%s`. We cannot draw that many steps from '
                     'this dataset. We suggest to set `%s=%s`.' %
                     (size, steps_name, steps, steps_name, size))
  return steps
class ModelInputs(object):
  """Holds a model's inputs together with generated or user-given names.

  Lets the inputs be transformed (e.g. into placeholders) while the original
  container structure (single value, list/tuple, or dict) is remembered.
  """

  def __init__(self, inputs):
    self._inputs = inputs
    self._is_dict = isinstance(inputs, dict)
    self._is_single_input = not isinstance(inputs, (list, tuple, dict))

    if self._is_dict:
      # Dict inputs are ordered by sorted key.
      ordered_keys = sorted(self._inputs.keys())
      self._input_names = list(ordered_keys)
      self._flattened_inputs = [self._inputs[key] for key in ordered_keys]
    else:
      self._flattened_inputs = nest.flatten(self._inputs)
      self._input_names = [
          'input_%d' % position
          for position in range(1, len(self._flattened_inputs) + 1)
      ]

  def get_input_names(self):
    """Returns the names the inputs are keyed by.

    For a list, tuple or single entry the names are generated as 'input_%d'.
    For a dictionary the sorted list of its keys is returned.
    """
    return self._input_names

  def get_symbolic_inputs(self, return_single_as_list=False):
    """Returns inputs to be set as self.inputs for a model."""
    # TODO(karmel): There is a side-effect here where what you get
    # with as_list and as_dict depends on whether you have called this
    # method first, since it modifies in place.
    for idx, name in enumerate(self._input_names):
      value = self._flattened_inputs[idx]

      if isinstance(value, (list, float, int)):
        value = np.asarray(value)
        if value.ndim == 1:
          value = np.expand_dims(value, 1)

      if isinstance(value, np.ndarray):
        # Only the batch dimension is left free in the placeholder shape.
        # That is the best that can be done with the information available;
        # callers needing custom placeholders should use
        # `model._set_inputs(placeholders)`.
        placeholder_shape = (None,) + tuple(value.shape[1:])
        if placeholder_shape == (None,):
          placeholder_shape = (None, 1)
        placeholder_dtype = dtypes.as_dtype(value.dtype)
        if placeholder_dtype.is_floating:
          placeholder_dtype = backend.floatx()
        value = backend.placeholder(
            shape=placeholder_shape, name=name, dtype=placeholder_dtype)
      elif isinstance(value, tensor_spec.TensorSpec):
        placeholder_shape = (None,) + tuple(value.shape.as_list()[1:])
        if placeholder_shape == (None,):
          placeholder_shape = (None, 1)
        value = backend.placeholder(
            shape=placeholder_shape, name=name, dtype=value.dtype)

      self._flattened_inputs[idx] = value

    if self._is_dict:
      return dict(zip(self._input_names, self._flattened_inputs))
    if return_single_as_list or not self._is_single_input:
      return self._flattened_inputs
    return self._flattened_inputs[0]

  def as_dict(self):
    """Yields `(name, input)` pairs, one per flattened input."""
    yield from zip(self._input_names, self._flattened_inputs)

  def as_list(self):
    """Returns the flattened inputs as a list."""
    return self._flattened_inputs
1799# Allow use of methods not exposed to the user.
1800# pylint: disable=protected-access
1803# pylint: enable=protected-access
def generic_output_names(outputs_list):
  """Returns 'output_1' ... 'output_N', one name per entry of `outputs_list`."""
  output_count = len(outputs_list)
  return ['output_%d' % position for position in range(1, output_count + 1)]
def should_run_validation(validation_freq, epoch):
  """Decides whether a validation pass is due after the given epoch.

  Args:
    validation_freq: Integer or list. If an integer, specifies how many training
      epochs to run before a new validation run is performed. If a list,
      specifies the epochs on which to run validation.
    epoch: Integer, the number of the training epoch just completed.

  Returns:
    Bool, True if validation should be run.

  Raises:
    ValueError: if `validation_freq` is an Integer and less than 1, or if
      it is neither an Integer nor a Sequence.
  """
  # Internally epochs count from 0; the public API counts from 1.
  completed_epochs = epoch + 1

  if not isinstance(validation_freq, int):
    if isinstance(validation_freq, collections.abc.Container):
      return completed_epochs in validation_freq
    raise ValueError('`validation_freq` must be an Integer or '
                     '`collections.abc.Container` (e.g. list, tuple, etc.)')

  if validation_freq < 1:
    raise ValueError('`validation_freq` can not be less than 1.')
  return completed_epochs % validation_freq == 0
def split_training_and_validation_data(x, y, sample_weights, validation_split):
  """Carves a trailing validation slice off `x`, `y` and `sample_weights`.

  The split index is `int(n * (1 - validation_split))`, where `n` is taken from
  the first entry of `x` (its `shape[0]` when it has a `shape`, otherwise its
  `len`). Entries before the index are kept for training, the rest are used
  for validation.

  Args:
    x: Input data; its first entry determines the number of samples.
    y: Target data, sliced at the same index as `x`.
    sample_weights: Sample weights; only sliced when truthy.
    validation_split: Fraction of the data held out for validation.

  Returns:
    Tuple `(x, y, sample_weights, val_x, val_y, val_sample_weights)`.
    `val_sample_weights` is None when `sample_weights` is falsy.

  Raises:
    ValueError: If `x` contains symbolic tensors.
  """
  if has_symbolic_tensors(x):
    raise ValueError('If your data is in the form of symbolic tensors, '
                     'you cannot use `validation_split`.')

  first_input = x[0]
  if hasattr(first_input, 'shape'):
    num_samples = first_input.shape[0]
  else:
    num_samples = len(first_input)
  split_at = int(num_samples * (1. - validation_split))

  train_x = generic_utils.slice_arrays(x, 0, split_at)
  val_x = generic_utils.slice_arrays(x, split_at)
  train_y = generic_utils.slice_arrays(y, 0, split_at)
  val_y = generic_utils.slice_arrays(y, split_at)

  val_sample_weights = None
  if sample_weights:
    train_weights = generic_utils.slice_arrays(sample_weights, 0, split_at)
    val_sample_weights = generic_utils.slice_arrays(sample_weights, split_at)
    sample_weights = train_weights
  return train_x, train_y, sample_weights, val_x, val_y, val_sample_weights
def unpack_validation_data(validation_data, raise_if_ambiguous=True):
  """Unpack validation data based on input type.

  The validation data is not touched if it is a dataset, a dataset iterator,
  a `data_utils.Sequence`, or any object without `__len__`. For other types of
  input (Numpy or tensor), it will be unpacked into a tuple of 3 which is
  x, y and sample weights.

  Args:
    validation_data: dataset, dataset iterator, or numpy, tensor tuple.
    raise_if_ambiguous: boolean on whether to fail if validation_data cannot be
      parsed. Otherwise simply return validation_data, None, None and defer the
      decision to the caller.

  Returns:
    tuple of 3, (x, y, sample_weights) for numpy and tensor input.

  Raises:
    ValueError: If `validation_data` has a length other than 2 or 3 and
      `raise_if_ambiguous` is True.
  """
  if (isinstance(validation_data, (iterator_ops.Iterator,
                                   iterator_ops.IteratorBase,
                                   data_types.DatasetV2,
                                   data_utils.Sequence))
      or not hasattr(validation_data, '__len__')):
    val_x = validation_data
    val_y = None
    val_sample_weight = None
  elif len(validation_data) == 2:
    try:
      val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
      val_sample_weight = None
    except ValueError:
      # Reports a length of 2 but does not unpack into two items; hand it back
      # whole as the inputs.
      val_x, val_y, val_sample_weight = validation_data, None, None
  elif len(validation_data) == 3:
    try:
      val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
    except ValueError:
      val_x, val_y, val_sample_weight = validation_data, None, None
  else:
    if raise_if_ambiguous:
      # `validation_data` is wrapped in a 1-tuple so that a tuple argument is
      # rendered as a single value; passing it bare would make `%` spread its
      # items over the format arguments and raise TypeError instead.
      raise ValueError(
          'When passing a `validation_data` argument, '
          'it must contain either 2 items (x_val, y_val), '
          'or 3 items (x_val, y_val, val_sample_weights), '
          'or alternatively it could be a dataset or a '
          'dataset or a dataset iterator. '
          'However we received `validation_data=%s`' % (validation_data,))
    val_x, val_y, val_sample_weight = validation_data, None, None
  return val_x, val_y, val_sample_weight
class TrainingLoop(object):
  """Base interface wrapping the fit/evaluate/predict training logic.

  Subclasses are expected to provide the concrete behaviour for a particular
  kind of data input and model condition; every method here raises
  `NotImplementedError`.

  Note that TrainingLoop is stateless, which means it doesn't contain any
  internal field and can be reused with different model and inputs.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Trains `model` on the given inputs and targets.

    Raises:
      NotImplementedError: Always; subclasses must override.
    """
    raise NotImplementedError()

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Computes the loss value & metrics values for `model` in test mode.

    Raises:
      NotImplementedError: Always; subclasses must override.
    """
    raise NotImplementedError()

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Generates predictions from `model` for the inputs `x`.

    Raises:
      NotImplementedError: Always; subclasses must override.
    """
    raise NotImplementedError()