Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/generic_utils.py: 21%

258 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python utilities required by Keras.""" 

16 

17import binascii 

18import codecs 

19import importlib 

20import marshal 

21import os 

22import re 

23import sys 

24import time 

25import types as python_types 

26 

27import numpy as np 

28import tensorflow.compat.v2 as tf 

29 

30from keras.src.utils import io_utils 

31from keras.src.utils import tf_inspect 

32 

33# isort: off 

34from tensorflow.python.util.tf_export import keras_export 

35 

36 

def func_dump(func):
    """Serializes a user defined function.

    The function's code object is marshaled and base64-encoded so it can be
    stored as ASCII text; `func_load` decodes and unmarshals it again.

    Args:
        func: the function to serialize.

    Returns:
        A tuple `(code, defaults, closure)`: `code` is the base64 (ASCII)
        string of the marshaled `func.__code__`, `defaults` is
        `func.__defaults__`, and `closure` is a tuple holding the contents of
        the closure cells, or `None` when the function has no closure.
    """
    code_obj = func.__code__
    if os.name == "nt":
        # Normalize Windows path separators in the recorded source filename.
        # This must be done on the code object, not on the marshaled bytes:
        # a byte-level replace of 0x5C ("\\") also hits bytecode, integer
        # constants and length prefixes that happen to contain that byte,
        # producing a payload that no longer unmarshals to the same code.
        # Only the top-level code object's filename is rewritten; nested code
        # objects (inner functions/lambdas) keep theirs.
        code_obj = code_obj.replace(
            co_filename=code_obj.co_filename.replace("\\", "/")
        )
    raw_code = marshal.dumps(code_obj)
    code = codecs.encode(raw_code, "base64").decode("ascii")
    defaults = func.__defaults__
    if func.__closure__:
        closure = tuple(c.cell_contents for c in func.__closure__)
    else:
        closure = None
    return code, defaults, closure

58 

59 

def func_load(code, defaults=None, closure=None, globs=None):
    """Deserializes a user defined function.

    Security note: this unmarshals `code` and turns it into executable code.
    Only call it on payloads from a trusted source (e.g. produced by
    `func_dump`), never on untrusted input.

    Args:
        code: bytecode of the function: a base64 string of the marshaled code
            object, as produced by `func_dump`. May also be the whole
            `(code, defaults, closure)` tuple/list returned by `func_dump`,
            in which case `defaults` and `closure` are taken from it.
        defaults: defaults of the function. A list is accepted and converted
            to a tuple (serialized configs such as JSON turn tuples into
            lists).
        closure: closure of the function, as an iterable of plain values
            and/or cell objects.
        globs: dictionary of global objects. Defaults to this module's
            globals.

    Returns:
        A function object.
    """
    if isinstance(code, (tuple, list)):  # unpack previous dump
        code, defaults, closure = code
    # `FunctionType` only accepts a tuple (or None) for defaults. Normalize
    # lists whether they arrived packed with `code` or as a separate argument.
    if isinstance(defaults, list):
        defaults = tuple(defaults)

    def ensure_value_to_cell(value):
        """Ensures that a value is converted to a python cell object.

        Args:
            value: Any value that needs to be casted to the cell type

        Returns:
            `value` itself if it already is a cell, otherwise a new cell
            whose contents are `value`.
        """

        def dummy_fn():
            value  # just access it so it gets captured in .__closure__

        cell_value = dummy_fn.__closure__[0]
        if not isinstance(value, type(cell_value)):
            return cell_value
        return value

    if closure is not None:
        closure = tuple(ensure_value_to_cell(_) for _ in closure)
    try:
        raw_code = codecs.decode(code.encode("ascii"), "base64")
    except (UnicodeEncodeError, binascii.Error):
        # Not a base64 payload; presumably raw marshal bytes stored in a str
        # (legacy format) — TODO confirm which producers still emit this.
        raw_code = code.encode("raw_unicode_escape")
    code = marshal.loads(raw_code)
    if globs is None:
        globs = globals()
    return python_types.FunctionType(
        code, globs, name=code.co_name, argdefs=defaults, closure=closure
    )

108 

109 

def has_arg(fn, name, accept_all=False):
    """Returns whether `fn` can be called with `name` as a keyword argument.

    A parameter counts if it is a regular positional-or-keyword argument or a
    keyword-only argument of `fn`.

    Args:
        fn: Callable to inspect.
        name: Keyword argument name to look for.
        accept_all: What to return if there is no parameter called `name` but
            the function accepts a `**kwargs` argument: when true, such a
            catch-all counts as accepting `name`.

    Returns:
        bool, whether `fn` accepts a `name` keyword argument.
    """
    spec = tf_inspect.getfullargspec(fn)
    if name in spec.args or name in spec.kwonlyargs:
        return True
    # No explicit parameter: fall back to the `**kwargs` catch-all, if allowed.
    return bool(accept_all) and spec.varkw is not None

126 

127 

@keras_export("keras.utils.Progbar")
class Progbar:
    """Displays a progress bar.

    Args:
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that should *not*
            be averaged over time. Metrics in this list will be displayed as-is.
            All others will be averaged by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
        unit_name: Display name for step counts (usually "step" or "sample").
    """

    def __init__(
        self,
        target,
        width=30,
        verbose=1,
        interval=0.05,
        stateful_metrics=None,
        unit_name="step",
    ):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        self.unit_name = unit_name
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        # Whether the bar is redrawn in place ("\b" + "\r") rather than each
        # update being printed on a new line.
        # NOTE(review): `posix` is in `sys.modules` on POSIX interpreters, so
        # this looks effectively always True there even when stdout is not a
        # TTY — confirm that is intended.
        self._dynamic_display = (
            (hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
            or "ipykernel" in sys.modules
            or "posix" in sys.modules
            or "PYCHARM_HOSTED" in os.environ
        )
        self._total_width = 0
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values = {}
        self._values_order = []
        self._start = time.time()
        self._last_update = 0
        self._time_at_epoch_start = self._start
        self._time_at_epoch_end = None
        self._time_after_first_step = None

    def update(self, current, values=None, finalize=None):
        """Updates the progress bar.

        Args:
            current: Index of current step.
            values: List of tuples: `(name, value_for_last_step)`. If `name` is
                in `stateful_metrics`, `value_for_last_step` will be displayed
                as-is. Else, an average of the metric over time will be
                displayed.
            finalize: Whether this is the last update for the progress bar. If
                `None`, defaults to `current >= self.target` (and to `False`
                when the target is unknown).
        """
        if finalize is None:
            if self.target is None:
                finalize = False
            else:
                finalize = current >= self.target

        values = values or []
        for k, v in values:
            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics:
                # In the case that progress bar doesn't have a target value in
                # the first epoch, both on_batch_end and on_epoch_end will be
                # called, which will cause 'current' and 'self._seen_so_far' to
                # have the same value. Force the minimal value to 1 here,
                # otherwise stateful_metric will be 0s.
                value_base = max(current - self._seen_so_far, 1)
                if k not in self._values:
                    self._values[k] = [v * value_base, value_base]
                else:
                    self._values[k][0] += v * value_base
                    self._values[k][1] += value_base
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current

        message = ""
        now = time.time()
        info = f" - {now - self._start:.0f}s"
        if current == self.target:
            self._time_at_epoch_end = now
        if self.verbose == 1:
            # Throttle redraws to at most one per `interval` seconds, but
            # always draw the final update.
            if now - self._last_update < self.interval and not finalize:
                return

            prev_total_width = self._total_width
            if self._dynamic_display:
                # Move the cursor back over the previous bar and redraw.
                message += "\b" * prev_total_width
                message += "\r"
            else:
                message += "\n"

            if self.target is not None:
                numdigits = int(np.log10(self.target)) + 1
                bar = ("%" + str(numdigits) + "d/%d [") % (current, self.target)
                prog = float(current) / self.target
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += "=" * (prog_width - 1)
                    if current < self.target:
                        bar += ">"
                    else:
                        bar += "="
                bar += "." * (self.width - prog_width)
                bar += "]"
            else:
                bar = "%7d/Unknown" % current

            self._total_width = len(bar)
            message += bar

            time_per_unit = self._estimate_step_duration(current, now)

            if self.target is None or finalize:
                info += self._format_time(time_per_unit, self.unit_name)
            else:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = "%d:%02d:%02d" % (
                        eta // 3600,
                        (eta % 3600) // 60,
                        eta % 60,
                    )
                elif eta > 60:
                    eta_format = "%d:%02d" % (eta // 60, eta % 60)
                else:
                    eta_format = "%ds" % eta

                # While in progress, the ETA replaces the elapsed-time field.
                info = f" - ETA: {eta_format}"

            for k in self._values_order:
                info += f" - {k}:"
                if isinstance(self._values[k], list):
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1])
                    )
                    info += self._format_average(avg)
                else:
                    info += f" {self._values[k]}"

            self._total_width += len(info)
            # Pad with spaces so a shorter line fully overwrites a longer one.
            if prev_total_width > self._total_width:
                info += " " * (prev_total_width - self._total_width)

            if finalize:
                info += "\n"

            message += info
            io_utils.print_msg(message, line_break=False)

        elif self.verbose == 2:
            # Semi-verbose mode prints a single summary line when finalizing.
            if finalize:
                if self.target is not None:
                    numdigits = int(np.log10(self.target)) + 1
                    count = ("%" + str(numdigits) + "d/%d") % (
                        current,
                        self.target,
                    )
                else:
                    # Same placeholder the verbose=1 bar uses for an unknown
                    # target; avoids calling `np.log10(None)`.
                    count = "%7d/Unknown" % current
                info = count + info
                for k in self._values_order:
                    info += f" - {k}:"
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1])
                    )
                    # Same formatting rule as verbose=1 (magnitude-based, so
                    # negative averages are not forced into exponent form).
                    info += self._format_average(avg)
                if self._time_at_epoch_end:
                    time_per_epoch = (
                        self._time_at_epoch_end - self._time_at_epoch_start
                    )
                    avg_time_per_step = time_per_epoch / self.target
                    self._time_at_epoch_start = now
                    self._time_at_epoch_end = None
                    info += " -" + self._format_time(time_per_epoch, "epoch")
                    info += " -" + self._format_time(
                        avg_time_per_step, self.unit_name
                    )
                # Always terminate the summary line (as verbose=1 does on
                # finalize) so subsequent output starts on a new line.
                info += "\n"
                message += info
                io_utils.print_msg(message, line_break=False)

        self._last_update = now

    def add(self, n, values=None):
        """Advances the bar by `n` steps past the last seen step.

        Args:
            n: Number of steps to add to the current position.
            values: Same as the `values` argument of `update`.
        """
        self.update(self._seen_so_far + n, values)

    @staticmethod
    def _format_average(avg):
        """Formats a metric average: fixed-point unless its magnitude is tiny.

        Args:
            avg: The numeric average to display.

        Returns:
            The value prefixed with a space, using 4 decimals when
            `abs(avg) > 1e-3` and 4-digit scientific notation otherwise.
        """
        if abs(avg) > 1e-3:
            return f" {avg:.4f}"
        return f" {avg:.4e}"

    def _format_time(self, time_per_unit, unit_name):
        """Formats a given duration to display to the user.

        Given the duration, this function formats it in seconds, milliseconds
        or microseconds and appends the unit (e.g. ms/step or s/epoch).

        Args:
            time_per_unit: the duration to display, in seconds.
            unit_name: the name of the unit to display.

        Returns:
            A string with the correctly formatted duration and units, starting
            with a space.
        """
        formatted = ""
        if time_per_unit >= 1 or time_per_unit == 0:
            formatted += f" {time_per_unit:.0f}s/{unit_name}"
        elif time_per_unit >= 1e-3:
            formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}"
        else:
            formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}"
        return formatted

    def _estimate_step_duration(self, current, now):
        """Estimate the duration of a single step.

        Given the step number `current` and the corresponding time `now` this
        function returns an estimate for how long a single step takes. If this
        is called before one step has been completed (i.e. `current == 0`) then
        zero is given as an estimate. The duration estimate ignores the duration
        of the (assumed to be non-representative) first step for estimates when
        more steps are available (i.e. `current>1`).

        Args:
            current: Index of current step.
            now: The current time.

        Returns:
            Estimate of the duration of a single step, in seconds.
        """
        if current:
            # there are a few special scenarios here:
            # 1) somebody is calling the progress bar without ever supplying
            #    step 1
            # 2) somebody is calling the progress bar and supplies step one
            #    multiple times, e.g. as part of a finalizing call
            # in these cases, we just fall back to the simple calculation
            if self._time_after_first_step is not None and current > 1:
                time_per_unit = (now - self._time_after_first_step) / (
                    current - 1
                )
            else:
                time_per_unit = (now - self._start) / current

            if current == 1:
                self._time_after_first_step = now
            return time_per_unit
        else:
            return 0

    def _update_stateful_metrics(self, stateful_metrics):
        """Adds `stateful_metrics` to the set of metrics shown un-averaged."""
        self.stateful_metrics = self.stateful_metrics.union(stateful_metrics)

def make_batches(size, batch_size):
    """Splits the range `[0, size)` into consecutive `(begin, end)` pairs.

    Every pair spans `batch_size` indices except possibly the last, which is
    clipped to `size`.

    Args:
        size: Integer, total size of the data to slice into batches.
        batch_size: Integer, batch size.

    Returns:
        A list of `(begin, end)` tuples of array indices.
    """
    batch_count = int(np.ceil(size / float(batch_size)))
    batches = []
    for batch_index in range(batch_count):
        begin = batch_index * batch_size
        end = min(size, (batch_index + 1) * batch_size)
        batches.append((begin, end))
    return batches

409 

410 

def slice_arrays(arrays, start=None, stop=None):
    """Slice an array or list of arrays.

    This takes an array-like, or a list of
    array-likes, and outputs:
        - arrays[start:stop] if `arrays` is an array-like
        - [x[start:stop] for x in arrays] if `arrays` is a list

    Can also work on list/array of indices: `slice_arrays(x, indices)`

    Args:
        arrays: Single array or list of arrays.
        start: can be an integer index (start index) or a list/array of indices
        stop: integer (stop index); should be None if `start` was a list.

    Returns:
        A slice of the array(s). `[None]` is returned when `arrays` is None,
        or when it is a single object that does not support indexing. Inside
        a list, `None` entries (and, for range slicing, non-indexable entries)
        map to `None`.

    Raises:
        ValueError: If the value of start is a list and stop is not None.
    """
    if arrays is None:
        return [None]
    if isinstance(start, list) and stop is not None:
        raise ValueError(
            "The stop argument has to be None if the value of start "
            f"is a list. Received start={start}, stop={stop}"
        )
    elif isinstance(arrays, list):
        if hasattr(start, "__len__"):
            # hdf5 datasets only support list objects as indices
            if hasattr(start, "shape"):
                start = start.tolist()
            return [None if x is None else x[start] for x in arrays]
        return [
            None
            if x is None
            else None
            if not hasattr(x, "__getitem__")
            else x[start:stop]
            for x in arrays
        ]
    else:
        if hasattr(start, "__len__"):
            if hasattr(start, "shape"):
                start = start.tolist()
            return arrays[start]
        # It is `arrays` that must support slicing (as in the list branch
        # above). Testing `start` here made integer/None starts fall through
        # to `[None]` instead of returning `arrays[start:stop]`.
        if hasattr(arrays, "__getitem__"):
            return arrays[start:stop]
        return [None]

461 

462 

def to_list(x):
    """Wraps `x` in a list unless it already is one.

    A list argument is returned unchanged (the same object). Anything else,
    e.g. a single tensor, becomes a one-element list containing it.

    Args:
        x: target object to be normalized.

    Returns:
        A list.
    """
    return x if isinstance(x, list) else [x]

478 

479 

def to_snake_case(name):
    """Converts a CamelCase name to snake_case.

    Args:
        name: The string to convert, typically a class name.

    Returns:
        The lower-cased, underscore-separated version of `name`. If that
        result starts with an underscore (e.g. a private class `_Foo`), it is
        prefixed with "private". An empty `name` yields an empty string.
    """
    intermediate = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    insecure = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower()
    # If the class is private the name starts with "_" which is not secure
    # for creating scopes. We prefix the name with "private" in this case.
    # `startswith` (unlike indexing `insecure[0]`) also handles an empty name.
    if not insecure.startswith("_"):
        return insecure
    return "private" + insecure

488 

489 

def is_all_none(structure):
    """Returns True if every leaf of the nested `structure` is None.

    Leaves are obtained with `tf.nest.flatten`. Each leaf is compared with
    `is None` instead of being truth-tested, because leaves may be Tensors.
    An empty structure counts as all-None.
    """
    return all(leaf is None for leaf in tf.nest.flatten(structure))

497 

498 

def check_for_unexpected_keys(name, input_dict, expected_values):
    """Raises if `input_dict` contains keys not listed in `expected_values`.

    Args:
        name: Name of the dictionary, used in the error message.
        input_dict: The dictionary whose keys are checked.
        expected_values: Iterable of allowed keys.

    Raises:
        ValueError: If any key of `input_dict` is not in `expected_values`.
    """
    extra_keys = set(input_dict.keys()).difference(expected_values)
    if not extra_keys:
        return
    raise ValueError(
        f"Unknown entries in {name} dictionary: {list(extra_keys)}. "
        f"Only expected following keys: {expected_values}"
    )

506 

507 

def validate_kwargs(
    kwargs, allowed_kwargs, error_message="Keyword argument not understood:"
):
    """Checks that all keyword arguments are in the set of allowed keys.

    Args:
        kwargs: Mapping (or iterable) of keyword argument names to validate.
        allowed_kwargs: Container of permitted names.
        error_message: First argument of the raised `TypeError`.

    Raises:
        TypeError: With args `(error_message, key)` for the first key, in
            iteration order, that is not in `allowed_kwargs`.
    """
    for key in kwargs:
        if key in allowed_kwargs:
            continue
        raise TypeError(error_message, key)

515 

516 

def default(method):
    """Decorates a method to detect overrides in subclasses.

    Sets the `_is_default` attribute of `method` to True and returns the same
    object; `is_default` reads that flag back.
    """
    setattr(method, "_is_default", True)
    return method

521 

522 

def is_default(method):
    """Returns the `_is_default` flag of `method`, or False if it is unset."""
    try:
        return method._is_default
    except AttributeError:
        return False

526 

527 

def populate_dict_with_module_objects(target_dict, modules, obj_filter):
    """Copies selected attributes of each module into `target_dict`.

    For every module in `modules`, each name listed by `dir(module)` whose
    value satisfies `obj_filter` is stored in `target_dict` under that name.
    Modules are processed in order, so later modules overwrite earlier ones
    on name clashes.

    Args:
        target_dict: Dictionary updated in place.
        modules: Iterable of modules (or any objects) to scan.
        obj_filter: Predicate called with each attribute value.
    """
    for module in modules:
        members = ((attr, getattr(module, attr)) for attr in dir(module))
        for attr, value in members:
            if obj_filter(value):
                target_dict[attr] = value

534 

535 

class LazyLoader(python_types.ModuleType):
    """Lazily import a module, mainly to avoid pulling in large dependencies.

    The real module `name` is imported the first time an attribute lookup on
    this object fails. The loaded module is then stored under `local_name` in
    `parent_module_globals`, and its attributes are copied onto this object.
    """

    def __init__(self, local_name, parent_module_globals, name):
        # Name and namespace under which the real module is bound once loaded.
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        super().__init__(name)

    def _load(self):
        """Imports the target module, binds it in the parent, and returns it."""
        target = importlib.import_module(self.__name__)
        self._parent_module_globals[self._local_name] = target
        # Mirror the module's namespace onto this proxy: code still holding a
        # reference to the LazyLoader then resolves attributes through normal
        # lookup, since `__getattr__` only runs when that lookup fails.
        vars(self).update(vars(target))
        return target

    def __getattr__(self, item):
        return getattr(self._load(), item)

558