Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/generic_utils.py: 21%
258 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python utilities required by Keras."""
17import binascii
18import codecs
19import importlib
20import marshal
21import os
22import re
23import sys
24import time
25import types as python_types
27import numpy as np
28import tensorflow.compat.v2 as tf
30from keras.src.utils import io_utils
31from keras.src.utils import tf_inspect
33# isort: off
34from tensorflow.python.util.tf_export import keras_export
def func_dump(func):
    """Serializes a user defined function.

    Args:
        func: the function to serialize.

    Returns:
        A tuple `(code, defaults, closure)`.
    """
    raw_code = marshal.dumps(func.__code__)
    if os.name == "nt":
        # On Windows, normalize backslashes inside the marshalled payload so
        # the serialized blob round-trips across path-separator conventions.
        raw_code = raw_code.replace(b"\\", b"/")
    code = codecs.encode(raw_code, "base64").decode("ascii")
    defaults = func.__defaults__
    closure = (
        tuple(cell.cell_contents for cell in func.__closure__)
        if func.__closure__
        else None
    )
    return code, defaults, closure
def func_load(code, defaults=None, closure=None, globs=None):
    """Deserializes a user defined function.

    Args:
        code: bytecode of the function.
        defaults: defaults of the function.
        closure: closure of the function.
        globs: dictionary of global objects.

    Returns:
        A function object.
    """
    # Accept the 3-tuple produced by `func_dump` directly.
    if isinstance(code, (tuple, list)):
        code, defaults, closure = code
    if isinstance(defaults, list):
        defaults = tuple(defaults)

    def _as_cell(value):
        """Wraps `value` in a closure cell unless it already is one."""

        def capture():
            value  # referenced only so it lands in capture.__closure__

        cell = capture.__closure__[0]
        # Pass real cell objects through untouched; wrap anything else.
        return value if isinstance(value, type(cell)) else cell

    if closure is not None:
        closure = tuple(_as_cell(item) for item in closure)
    try:
        raw_code = codecs.decode(code.encode("ascii"), "base64")
    except (UnicodeEncodeError, binascii.Error):
        # Legacy payloads were stored as raw escaped strings, not base64.
        raw_code = code.encode("raw_unicode_escape")
    code = marshal.loads(raw_code)
    if globs is None:
        globs = globals()
    return python_types.FunctionType(
        code, globs, name=code.co_name, argdefs=defaults, closure=closure
    )
def has_arg(fn, name, accept_all=False):
    """Checks if a callable accepts a given keyword argument.

    Args:
        fn: Callable to inspect.
        name: Check if `fn` can be called with `name` as a keyword argument.
        accept_all: What to return if there is no parameter called `name` but
            the function accepts a `**kwargs` argument.

    Returns:
        bool, whether `fn` accepts a `name` keyword argument.
    """
    spec = tf_inspect.getfullargspec(fn)
    # A `**kwargs` catch-all accepts any keyword when accept_all is set.
    if accept_all and spec.varkw is not None:
        return True
    named_params = set(spec.args) | set(spec.kwonlyargs)
    return name in named_params
@keras_export("keras.utils.Progbar")
class Progbar:
    """Displays a progress bar.

    Args:
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that should *not*
            be averaged over time. Metrics in this list will be displayed as-is.
            All others will be averaged by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
        unit_name: Display name for step counts (usually "step" or "sample").
    """

    def __init__(
        self,
        target,
        width=30,
        verbose=1,
        interval=0.05,
        stateful_metrics=None,
        unit_name="step",
    ):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        self.unit_name = unit_name
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        # Whether the output supports in-place line rewriting ("\b"/"\r");
        # otherwise each refresh is emitted on a new line.
        self._dynamic_display = (
            (hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
            or "ipykernel" in sys.modules
            or "posix" in sys.modules
            or "PYCHARM_HOSTED" in os.environ
        )
        self._total_width = 0
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values = {}
        self._values_order = []
        self._start = time.time()
        self._last_update = 0
        self._time_at_epoch_start = self._start
        self._time_at_epoch_end = None
        self._time_after_first_step = None

    def update(self, current, values=None, finalize=None):
        """Updates the progress bar.

        Args:
            current: Index of current step.
            values: List of tuples: `(name, value_for_last_step)`. If `name` is
                in `stateful_metrics`, `value_for_last_step` will be displayed
                as-is. Else, an average of the metric over time will be
                displayed.
            finalize: Whether this is the last update for the progress bar. If
                `None`, defaults to `current >= self.target`.
        """
        if finalize is None:
            if self.target is None:
                finalize = False
            else:
                finalize = current >= self.target

        values = values or []
        for k, v in values:
            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics:
                # In the case that progress bar doesn't have a target value in
                # the first epoch, both on_batch_end and on_epoch_end will be
                # called, which will cause 'current' and 'self._seen_so_far' to
                # have the same value. Force the minimal value to 1 here,
                # otherwise stateful_metric will be 0s.
                value_base = max(current - self._seen_so_far, 1)
                if k not in self._values:
                    self._values[k] = [v * value_base, value_base]
                else:
                    self._values[k][0] += v * value_base
                    self._values[k][1] += value_base
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current

        message = ""
        now = time.time()
        info = f" - {now - self._start:.0f}s"
        if current == self.target:
            self._time_at_epoch_end = now
        if self.verbose == 1:
            # Rate-limit refreshes to `interval` unless this is the last one.
            if now - self._last_update < self.interval and not finalize:
                return

            prev_total_width = self._total_width
            if self._dynamic_display:
                message += "\b" * prev_total_width
                message += "\r"
            else:
                message += "\n"

            if self.target is not None:
                numdigits = int(np.log10(self.target)) + 1
                bar = ("%" + str(numdigits) + "d/%d [") % (current, self.target)
                prog = float(current) / self.target
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += "=" * (prog_width - 1)
                    if current < self.target:
                        bar += ">"
                    else:
                        bar += "="
                bar += "." * (self.width - prog_width)
                bar += "]"
            else:
                bar = "%7d/Unknown" % current

            self._total_width = len(bar)
            message += bar

            time_per_unit = self._estimate_step_duration(current, now)

            if self.target is None or finalize:
                info += self._format_time(time_per_unit, self.unit_name)
            else:
                # While mid-epoch, replace elapsed time with an ETA readout.
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = "%d:%02d:%02d" % (
                        eta // 3600,
                        (eta % 3600) // 60,
                        eta % 60,
                    )
                elif eta > 60:
                    eta_format = "%d:%02d" % (eta // 60, eta % 60)
                else:
                    eta_format = "%ds" % eta

                info = f" - ETA: {eta_format}"

            for k in self._values_order:
                info += f" - {k}:"
                if isinstance(self._values[k], list):
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1])
                    )
                    if abs(avg) > 1e-3:
                        info += f" {avg:.4f}"
                    else:
                        info += f" {avg:.4e}"
                else:
                    info += f" {self._values[k]}"

            self._total_width += len(info)
            # Pad with spaces so a shorter line fully overwrites the previous.
            if prev_total_width > self._total_width:
                info += " " * (prev_total_width - self._total_width)

            if finalize:
                info += "\n"

            message += info
            io_utils.print_msg(message, line_break=False)
            message = ""

        elif self.verbose == 2:
            if finalize:
                numdigits = int(np.log10(self.target)) + 1
                count = ("%" + str(numdigits) + "d/%d") % (current, self.target)
                info = count + info
                for k in self._values_order:
                    info += f" - {k}:"
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1])
                    )
                    # Fix: use abs() so tiny *negative* averages also switch
                    # to scientific notation, matching the verbose == 1 branch
                    # above (previously `avg > 1e-3`, which chose fixed-point
                    # format for all negative values regardless of magnitude).
                    if abs(avg) > 1e-3:
                        info += f" {avg:.4f}"
                    else:
                        info += f" {avg:.4e}"
                if self._time_at_epoch_end:
                    time_per_epoch = (
                        self._time_at_epoch_end - self._time_at_epoch_start
                    )
                    avg_time_per_step = time_per_epoch / self.target
                    self._time_at_epoch_start = now
                    self._time_at_epoch_end = None
                    info += " -" + self._format_time(time_per_epoch, "epoch")
                    info += " -" + self._format_time(
                        avg_time_per_step, self.unit_name
                    )
                    info += "\n"
                message += info
                io_utils.print_msg(message, line_break=False)
                message = ""

        self._last_update = now

    def add(self, n, values=None):
        """Advances the bar by `n` steps; see `update` for `values`."""
        self.update(self._seen_so_far + n, values)

    def _format_time(self, time_per_unit, unit_name):
        """format a given duration to display to the user.

        Given the duration, this function formats it in either milliseconds
        or seconds and displays the unit (i.e. ms/step or s/epoch)

        Args:
            time_per_unit: the duration to display
            unit_name: the name of the unit to display

        Returns:
            a string with the correctly formatted duration and units
        """
        formatted = ""
        if time_per_unit >= 1 or time_per_unit == 0:
            formatted += f" {time_per_unit:.0f}s/{unit_name}"
        elif time_per_unit >= 1e-3:
            formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}"
        else:
            formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}"
        return formatted

    def _estimate_step_duration(self, current, now):
        """Estimate the duration of a single step.

        Given the step number `current` and the corresponding time `now` this
        function returns an estimate for how long a single step takes. If this
        is called before one step has been completed (i.e. `current == 0`) then
        zero is given as an estimate. The duration estimate ignores the duration
        of the (assumed to be non-representative) first step for estimates when
        more steps are available (i.e. `current>1`).

        Args:
            current: Index of current step.
            now: The current time.

        Returns: Estimate of the duration of a single step.
        """
        if current:
            # there are a few special scenarios here:
            # 1) somebody is calling the progress bar without ever supplying
            #    step 1
            # 2) somebody is calling the progress bar and supplies step one
            #    multiple times, e.g. as part of a finalizing call
            # in these cases, we just fall back to the simple calculation
            if self._time_after_first_step is not None and current > 1:
                time_per_unit = (now - self._time_after_first_step) / (
                    current - 1
                )
            else:
                time_per_unit = (now - self._start) / current

            if current == 1:
                self._time_after_first_step = now
            return time_per_unit
        else:
            return 0

    def _update_stateful_metrics(self, stateful_metrics):
        """Marks additional metric names as stateful (displayed as-is)."""
        self.stateful_metrics = self.stateful_metrics.union(stateful_metrics)
def make_batches(size, batch_size):
    """Returns a list of batch indices (tuples of indices).

    Args:
        size: Integer, total size of the data to slice into batches.
        batch_size: Integer, batch size.

    Returns:
        A list of tuples of array indices.
    """
    num_batches = int(np.ceil(size / float(batch_size)))
    batches = []
    for batch_index in range(num_batches):
        begin = batch_index * batch_size
        # The final batch is clipped so it never runs past `size`.
        batches.append((begin, min(size, begin + batch_size)))
    return batches
def slice_arrays(arrays, start=None, stop=None):
    """Slice an array or list of arrays.

    This takes an array-like, or a list of
    array-likes, and outputs:
        - arrays[start:stop] if `arrays` is an array-like
        - [x[start:stop] for x in arrays] if `arrays` is a list

    Can also work on list/array of indices: `slice_arrays(x, indices)`

    Args:
        arrays: Single array or list of arrays.
        start: can be an integer index (start index) or a list/array of indices
        stop: integer (stop index); should be None if `start` was a list.

    Returns:
        A slice of the array(s).

    Raises:
        ValueError: If the value of start is a list and stop is not None.
    """
    if arrays is None:
        return [None]
    if isinstance(start, list) and stop is not None:
        raise ValueError(
            "The stop argument has to be None if the value of start "
            f"is a list. Received start={start}, stop={stop}"
        )
    if isinstance(arrays, list):
        if hasattr(start, "__len__"):
            # hdf5 datasets only support list objects as indices
            if hasattr(start, "shape"):
                start = start.tolist()
            return [None if x is None else x[start] for x in arrays]
        sliced = []
        for x in arrays:
            # Entries without item access (e.g. plain scalars) become None.
            if x is None or not hasattr(x, "__getitem__"):
                sliced.append(None)
            else:
                sliced.append(x[start:stop])
        return sliced
    if hasattr(start, "__len__"):
        if hasattr(start, "shape"):
            start = start.tolist()
        return arrays[start]
    if hasattr(start, "__getitem__"):
        return arrays[start:stop]
    return [None]
def to_list(x):
    """Normalizes a list/tensor into a list.

    If a tensor is passed, we return
    a list of size 1 containing the tensor.

    Args:
        x: target object to be normalized.

    Returns:
        A list.
    """
    return x if isinstance(x, list) else [x]
def to_snake_case(name):
    """Converts a CamelCase `name` to snake_case for use in scope names."""
    partially_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    snake = re.sub("([a-z])([A-Z])", r"\1_\2", partially_split).lower()
    # If the class is private the name starts with "_" which is not secure
    # for creating scopes. We prefix the name with "private" in this case.
    if snake[0] == "_":
        return "private" + snake
    return snake
def is_all_none(structure):
    """Returns True iff every flattened leaf of `structure` is None."""
    # Only `is None` identity checks are applied to the leaves: they may be
    # Tensors, whose truth value cannot be evaluated directly. The generator
    # feeds plain bools to `all`, so no leaf truthiness is ever consulted.
    return all(element is None for element in tf.nest.flatten(structure))
def check_for_unexpected_keys(name, input_dict, expected_values):
    """Raises ValueError if `input_dict` has keys outside `expected_values`."""
    unexpected = set(input_dict) - set(expected_values)
    if unexpected:
        raise ValueError(
            f"Unknown entries in {name} dictionary: {list(unexpected)}. "
            f"Only expected following keys: {expected_values}"
        )
def validate_kwargs(
    kwargs, allowed_kwargs, error_message="Keyword argument not understood:"
):
    """Checks that all keyword arguments are in the set of allowed keys.

    Args:
        kwargs: mapping of keyword arguments to validate.
        allowed_kwargs: container of permitted keyword names.
        error_message: message for the raised `TypeError`.

    Raises:
        TypeError: on the first keyword not in `allowed_kwargs`.
    """
    for keyword in kwargs:
        if keyword in allowed_kwargs:
            continue
        raise TypeError(error_message, keyword)
def default(method):
    """Decorates a method to detect overrides in subclasses."""
    # Overriding subclasses will not carry this marker attribute.
    setattr(method, "_is_default", True)
    return method
def is_default(method):
    """Check if a method is decorated with the `default` wrapper."""
    marker = getattr(method, "_is_default", False)
    return marker
def populate_dict_with_module_objects(target_dict, modules, obj_filter):
    """Fills `target_dict` with attributes of `modules` passing `obj_filter`."""
    for module in modules:
        members = ((attr, getattr(module, attr)) for attr in dir(module))
        target_dict.update(
            {attr: value for attr, value in members if obj_filter(value)}
        )
class LazyLoader(python_types.ModuleType):
    """Lazily import a module, mainly to avoid pulling in large dependencies."""

    def __init__(self, local_name, parent_module_globals, name):
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        super().__init__(name)

    def _load(self):
        """Load the module and insert it into the parent's globals."""
        # Perform the real import and publish it under the parent's namespace
        # so subsequent references bypass this proxy entirely.
        target = importlib.import_module(self.__name__)
        self._parent_module_globals[self._local_name] = target
        # Copy the real module's dict onto this proxy: __getattr__ only fires
        # on failed lookups, so held references stay fast after loading.
        self.__dict__.update(target.__dict__)
        return target

    def __getattr__(self, item):
        return getattr(self._load(), item)