# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities related to layer/model functionality."""

import copy
import functools
import re
import weakref

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import initializers
from keras.src.utils import io_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.utils.get_source_inputs")
def get_source_inputs(tensor, layer=None, node_index=None):
    """Returns the list of input tensors necessary to compute `tensor`.

    Output will always be a list of tensors
    (potentially with 1 element).

    Args:
        tensor: The tensor to start from.
        layer: Origin layer of the tensor. Will be
            determined via tensor._keras_history if not provided.
        node_index: Origin node index of the tensor.

    Returns:
        List of input tensors.
    """
    if not hasattr(tensor, "_keras_history"):
        return tensor

    if layer is None or node_index:
        layer, node_index, _ = tensor._keras_history
    if not layer._inbound_nodes:
        return [tensor]
    else:
        node = layer._inbound_nodes[node_index]
        if node.is_input:
            # Reached an Input layer, stop recursion.
            return tf.nest.flatten(node.input_tensors)
        else:
            source_tensors = []
            for layer, node_index, _, tensor in node.iterate_inbound():
                previous_sources = get_source_inputs(tensor, layer, node_index)
                # Avoid input redundancy.
                for x in previous_sources:
                    if all(x is not t for t in source_tensors):
                        source_tensors.append(x)
            return source_tensors
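

# Illustrative sketch (not part of the original module; helper name is
# hypothetical): for a functional graph, `get_source_inputs` walks
# `_keras_history` back to the `Input` tensors.
def _example_get_source_inputs():
    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(2)(inputs)
    # Returns a one-element list containing the original input tensor.
    return get_source_inputs(outputs)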


def validate_string_arg(
    input_data,
    allowable_strings,
    layer_name,
    arg_name,
    allow_none=False,
    allow_callables=False,
):
    """Validates the correctness of a string-based arg."""
    if allow_none and input_data is None:
        return
    elif allow_callables and callable(input_data):
        return
    elif isinstance(input_data, str) and input_data in allowable_strings:
        return
    else:
        allowed_args = "`None`, " if allow_none else ""
        allowed_args += "a `Callable`, " if allow_callables else ""
        allowed_args += f"or one of the following values: {allowable_strings}"
        if allow_callables:
            callable_note = (
                f"If restoring a model and `{arg_name}` is a custom callable, "
                "please ensure the callable is registered as a custom object. "
                "See https://www.tensorflow.org/guide/keras/save_and_serialize"
                "#registering_the_custom_object for details. "
            )
        else:
            callable_note = ""
        raise ValueError(
            f"Unknown value for `{arg_name}` argument of layer {layer_name}. "
            f"{callable_note}Allowed values are: {allowed_args}. Received: "
            f"{input_data}"
        )
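

# Illustrative sketch (hypothetical argument values): a valid string passes
# silently; anything else raises a ValueError listing the allowed options.
def _example_validate_string_arg():
    validate_string_arg(
        "valid",
        {"valid", "also_valid"},
        layer_name="my_layer",
        arg_name="mode",
        allow_none=True,
    )  # Returns None; an unknown string would raise ValueError.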


def count_params(weights):
    """Count the total number of scalars composing the weights.

    Args:
        weights: An iterable containing the weights on which to compute
            params.

    Returns:
        The total number of scalars composing the weights.
    """
    unique_weights = {id(w): w for w in weights}.values()
    # Ignore TrackableWeightHandlers, which will not have a shape defined.
    unique_weights = [w for w in unique_weights if hasattr(w, "shape")]
    weight_shapes = [w.shape.as_list() for w in unique_weights]
    standardized_weight_shapes = [
        [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
    ]
    return int(sum(np.prod(p) for p in standardized_weight_shapes))
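

# Illustrative sketch (hypothetical helper name): weights are de-duplicated
# by object id, so a shared variable is only counted once.
def _example_count_params():
    w = tf.Variable(tf.zeros([3, 4]))
    b = tf.Variable(tf.zeros([4]))
    # The same variable passed twice counts once: 3 * 4 + 4 = 16.
    return count_params([w, b, w])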


def weight_memory_size(weights):
    """Calculate the memory footprint for weights based on their dtypes.

    Args:
        weights: An iterable containing the weights whose memory size should
            be computed.

    Returns:
        The total memory size (in Bytes) of the weights.
    """
    unique_weights = {id(w): w for w in weights}.values()

    total_memory_size = 0
    for w in unique_weights:
        # Ignore TrackableWeightHandlers, which will not have a shape defined.
        if not hasattr(w, "shape"):
            continue
        elif None in w.shape.as_list():
            continue
        weight_shape = np.prod(w.shape.as_list())
        per_param_size = w.dtype.size
        total_memory_size += weight_shape * per_param_size
    return total_memory_size
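

# Illustrative sketch (hypothetical helper name): a float32 parameter takes
# 4 bytes, so a 3x4 kernel plus a 4-element bias occupies (12 + 4) * 4 = 64
# bytes.
def _example_weight_memory_size():
    w = tf.Variable(tf.zeros([3, 4]))  # float32
    b = tf.Variable(tf.zeros([4]))
    return weight_memory_size([w, b])  # 64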


def dtensor_variable_summary(weights):
    """Group and calculate DTensor-based weights memory size.

    Since DTensor weights can be sharded across multiple devices, the result
    is grouped by the layout/sharding spec of the variables, so that an
    accurate per-device memory size can be calculated.

    Args:
        weights: An iterable containing the weights whose size should be
            computed.

    Returns:
        total_weight_count, total_memory_size and per_sharing_spec_result,
        which is a dict with the normalized layout spec as key and a tuple
        of weight count and weight size as value.
    """
    unique_weights = {id(w): w for w in weights}.values()
    total_weight_count = 0
    total_memory_size = 0
    per_sharing_spec_result = {}
    for w in unique_weights:
        # Ignore TrackableWeightHandlers, which will not have a shape defined.
        if not hasattr(w, "shape"):
            continue
        if not isinstance(w, tf.experimental.dtensor.DVariable):
            continue
        layout = w.layout
        # Remove duplicated axes and sort the dimension names: 1D and 2D
        # replicated variables are both fully replicated, and a
        # [batch, model] sharding has the same memory footprint as a
        # [model, batch] layout.
        reduced_sharding_spec = list(sorted(set(layout.sharding_specs)))
        if tf.experimental.dtensor.UNSHARDED in reduced_sharding_spec:
            reduced_sharding_spec.remove(tf.experimental.dtensor.UNSHARDED)
        reduced_sharding_spec = tuple(reduced_sharding_spec)  # For dict key
        weight_count, memory_size = per_sharing_spec_result.get(
            reduced_sharding_spec, (0, 0)
        )
        reduced_weight_shape = np.prod(w.shape.as_list())
        per_param_size = w.dtype.size
        weight_count += reduced_weight_shape
        memory_size += reduced_weight_shape * per_param_size
        per_sharing_spec_result[reduced_sharding_spec] = (
            weight_count,
            memory_size,
        )
        total_weight_count += reduced_weight_shape
        total_memory_size += reduced_weight_shape * per_param_size
    return total_weight_count, total_memory_size, per_sharing_spec_result


def print_dtensor_variable_summary(model, print_fn, line_length):
    if getattr(model, "_layout_map", None) is not None:
        mesh = model._layout_map.get_default_mesh()
    elif hasattr(model, "distribute_strategy") and hasattr(
        model.distribute_strategy, "_mesh"
    ):
        mesh = model.distribute_strategy._mesh
    else:
        # Not running with DTensor
        mesh = None
    if mesh:
        (
            total_weight_count,
            total_memory_size,
            per_sharing_spec_result,
        ) = dtensor_variable_summary(model.weights)
        total_per_device_memory_size = 0
        for sharding_spec in sorted(per_sharing_spec_result.keys()):
            count, memory_size = per_sharing_spec_result[sharding_spec]
            if len(sharding_spec) == 0:
                print_fn(
                    f"{count} / {total_weight_count} params "
                    f"({readable_memory_size(memory_size)}) "
                    "are fully replicated"
                )
                per_device_size = memory_size
            else:
                sharding_factor = np.prod(
                    [mesh.dim_size(s) for s in sharding_spec]
                )
                per_device_size = memory_size / sharding_factor
                print_fn(
                    f"{count} / {total_weight_count} params "
                    f"({readable_memory_size(memory_size)}) are sharded based "
                    f"on spec '{sharding_spec}' and across {sharding_factor} "
                    f"devices."
                )
            total_per_device_memory_size += per_device_size
        print_fn(
            "Overall per device memory usage: "
            f"{readable_memory_size(total_per_device_memory_size)}"
        )
        print_fn(
            "Overall sharding factor: {:.2f}".format(
                total_memory_size / total_per_device_memory_size
            )
        )
        print_fn("_" * line_length)


def readable_memory_size(weight_memory_size):
    """Convert the weight memory size (Bytes) to a readable string."""
    units = ["Byte", "KB", "MB", "GB", "TB", "PB"]
    scale = 1024
    for unit in units:
        if weight_memory_size / scale < 1:
            return "{:.2f} {}".format(weight_memory_size, unit)
        else:
            weight_memory_size /= scale
    return "{:.2f} {}".format(weight_memory_size, units[-1])
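

# Illustrative sketch (hypothetical helper name): each unit step divides by
# 1024, so 3 * 1024**2 bytes renders as "3.00 MB".
def _example_readable_memory_size():
    return readable_memory_size(3 * 1024**2)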


def get_layer_index_bound_by_layer_name(model, layer_range=None):
    """Get the layer indexes from the model based on layer names.

    The layer indexes can be used to slice the model into sub models for
    display.

    Args:
        model: `tf.keras.Model` instance.
        layer_range: a list or tuple of 2 strings, the starting layer name
            and ending layer name (both inclusive) for the result. All
            layers will be included when `None` is provided.

    Returns:
        The index values of the layers based on their unique names
        (`layer_range`). Output will be [first_layer_index,
        last_layer_index + 1].
    """
    if layer_range is not None:
        if len(layer_range) != 2:
            raise ValueError(
                "layer_range must be a list or tuple of length 2. Received: "
                f"layer_range = {layer_range} of length {len(layer_range)}"
            )
        if not isinstance(layer_range[0], str) or not isinstance(
            layer_range[1], str
        ):
            raise ValueError(
                "layer_range should contain string type only. "
                f"Received: {layer_range}"
            )
    else:
        return [0, len(model.layers)]

    lower_index = [
        idx
        for idx, layer in enumerate(model.layers)
        if re.match(layer_range[0], layer.name)
    ]
    upper_index = [
        idx
        for idx, layer in enumerate(model.layers)
        if re.match(layer_range[1], layer.name)
    ]

    if not lower_index or not upper_index:
        raise ValueError(
            "Passed layer_names do not match the layer names in the model. "
            f"Received: {layer_range}"
        )

    if min(lower_index) > max(upper_index):
        return [min(upper_index), max(lower_index) + 1]
    return [min(lower_index), max(upper_index) + 1]
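

# Illustrative sketch (hypothetical layer names): with layers "a", "b", "c",
# the range ("a", "b") maps to slice bounds [0, 2].
def _example_layer_index_bounds():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(1, name="a"),
            tf.keras.layers.Dense(1, name="b"),
            tf.keras.layers.Dense(1, name="c"),
        ]
    )
    return get_layer_index_bound_by_layer_name(model, ("a", "b"))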


def print_summary(
    model,
    line_length=None,
    positions=None,
    print_fn=None,
    expand_nested=False,
    show_trainable=False,
    layer_range=None,
):
    """Prints a summary of a model.

    Args:
        model: Keras model instance.
        line_length: Total length of printed lines
            (e.g. set this to adapt the display to different
            terminal window sizes).
        positions: Relative or absolute positions of log elements in each
            line. If not provided, defaults to `[0.3, 0.6, 0.70, 1.]`.
        print_fn: Print function to use.
            It will be called on each line of the summary.
            You can set it to a custom function
            in order to capture the string summary.
            It defaults to `print` (prints to stdout).
        expand_nested: Whether to expand the nested models.
            If not provided, defaults to `False`.
        show_trainable: Whether to show if a layer is trainable.
            If not provided, defaults to `False`.
        layer_range: List or tuple containing two strings,
            the starting layer name and ending layer name (both inclusive),
            indicating the range of layers to be printed in the summary. The
            strings could also be regexes instead of an exact name. In this
            case, the starting layer will be the first layer that matches
            `layer_range[0]` and the ending layer will be the last element
            that matches `layer_range[1]`. By default (`None`) all layers in
            the model are included in the summary.
    """
    if print_fn is None:
        print_fn = io_utils.print_msg

    if model.__class__.__name__ == "Sequential":
        sequential_like = True
    elif not model._is_graph_network:
        # We treat subclassed models as a simple sequence of layers, for
        # logging purposes.
        sequential_like = True
    else:
        sequential_like = True
        nodes_by_depth = model._nodes_by_depth.values()
        nodes = []
        for v in nodes_by_depth:
            if (len(v) > 1) or (
                len(v) == 1 and len(tf.nest.flatten(v[0].keras_inputs)) > 1
            ):
                # If the model has multiple nodes,
                # or if the nodes have multiple inbound_layers,
                # the model is no longer sequential.
                sequential_like = False
                break
            nodes += v
        if sequential_like:
            # Search for shared layers.
            for layer in model.layers:
                flag = False
                for node in layer._inbound_nodes:
                    if node in nodes:
                        if flag:
                            sequential_like = False
                            break
                        else:
                            flag = True
                if not sequential_like:
                    break

    if sequential_like:
        line_length = line_length or 65
        positions = positions or [0.45, 0.85, 1.0]
        if positions[-1] <= 1:
            positions = [int(line_length * p) for p in positions]
        # Header names for the different log elements.
        to_display = ["Layer (type)", "Output Shape", "Param #"]
    else:
        line_length = line_length or 98
        positions = positions or [0.3, 0.6, 0.70, 1.0]
        if positions[-1] <= 1:
            positions = [int(line_length * p) for p in positions]
        # Header names for the different log elements.
        to_display = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
        relevant_nodes = []
        for v in model._nodes_by_depth.values():
            relevant_nodes += v

    if show_trainable:
        line_length += 11
        positions.append(line_length)
        to_display.append("Trainable")

    layer_range = get_layer_index_bound_by_layer_name(model, layer_range)

    def print_row(fields, positions, nested_level=0):
        left_to_print = [str(x) for x in fields]
        while any(left_to_print):
            line = ""
            for col in range(len(left_to_print)):
                if col > 0:
                    start_pos = positions[col - 1]
                else:
                    start_pos = 0
                end_pos = positions[col]
                # Leave room for 2 spaces to delineate columns;
                # we don't need any if we are printing the last column.
                space = 2 if col != len(positions) - 1 else 0
                cutoff = end_pos - start_pos - space
                # Except for last col, offset by one to align the start of
                # the column.
                if col != len(positions) - 1:
                    cutoff -= 1
                if col == 0:
                    cutoff -= nested_level
                fit_into_line = left_to_print[col][:cutoff]
                # For nicer formatting we line-break on seeing end of
                # tuple/dict etc.
                line_break_conditions = ("),", "},", "],", "',")
                candidate_cutoffs = [
                    fit_into_line.find(x) + len(x)
                    for x in line_break_conditions
                    if fit_into_line.find(x) >= 0
                ]
                if candidate_cutoffs:
                    cutoff = min(candidate_cutoffs)
                    fit_into_line = fit_into_line[:cutoff]

                if col == 0:
                    line += "|" * nested_level + " "
                line += fit_into_line
                line += " " * space if space else ""
                left_to_print[col] = left_to_print[col][cutoff:]

                # Pad out to the next position.
                # Make space for nested_level for the last column.
                if nested_level and col == len(positions) - 1:
                    line += " " * (positions[col] - len(line) - nested_level)
                else:
                    line += " " * (positions[col] - len(line))
            line += "|" * nested_level
            print_fn(line)

    print_fn(f'Model: "{model.name}"')
    print_fn("_" * line_length)
    print_row(to_display, positions)
    print_fn("=" * line_length)

    def print_layer_summary(layer, nested_level=0):
        """Prints a summary for a single layer.

        Args:
            layer: target layer.
            nested_level: level of nesting of the layer inside its parent
                layer (e.g. 0 for a top-level layer, 1 for a nested layer).
        """
        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = "multiple"
        except RuntimeError:  # output_shape unknown in Eager mode.
            output_shape = "?"
        name = layer.name
        cls_name = layer.__class__.__name__
        if not layer.built and not getattr(layer, "_is_graph_network", False):
            # If a subclassed model has a layer that is not called in
            # Model.call, the layer will not be built and we cannot call
            # layer.count_params().
            params = "0 (unused)"
        else:
            params = layer.count_params()
        fields = [name + " (" + cls_name + ")", output_shape, params]

        if show_trainable:
            fields.append("Y" if layer.trainable else "N")

        print_row(fields, positions, nested_level)

    def print_layer_summary_with_connections(layer, nested_level=0):
        """Prints a summary for a single layer (including its connections).

        Args:
            layer: target layer.
            nested_level: level of nesting of the layer inside its parent
                layer (e.g. 0 for a top-level layer, 1 for a nested layer).
        """
        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = "multiple"
        connections = []
        for node in layer._inbound_nodes:
            if relevant_nodes and node not in relevant_nodes:
                # Node is not part of the current network.
                continue

            for (
                inbound_layer,
                node_index,
                tensor_index,
                _,
            ) in node.iterate_inbound():
                connections.append(
                    f"{inbound_layer.name}[{node_index}][{tensor_index}]"
                )

        name = layer.name
        cls_name = layer.__class__.__name__
        fields = [
            name + " (" + cls_name + ")",
            output_shape,
            layer.count_params(),
            connections,
        ]

        if show_trainable:
            fields.append("Y" if layer.trainable else "N")

        print_row(fields, positions, nested_level)

    def print_layer(layer, nested_level=0, is_nested_last=False):
        if sequential_like:
            print_layer_summary(layer, nested_level)
        else:
            print_layer_summary_with_connections(layer, nested_level)

        if expand_nested and hasattr(layer, "layers") and layer.layers:
            print_fn(
                "|" * (nested_level + 1)
                + "¯" * (line_length - 2 * nested_level - 2)
                + "|" * (nested_level + 1)
            )

            nested_layer = layer.layers
            is_nested_last = False
            for i in range(len(nested_layer)):
                if i == len(nested_layer) - 1:
                    is_nested_last = True
                print_layer(nested_layer[i], nested_level + 1, is_nested_last)

            print_fn(
                "|" * nested_level
                + "¯" * (line_length - 2 * nested_level)
                + "|" * nested_level
            )

        if not is_nested_last:
            print_fn(
                "|" * nested_level
                + " " * (line_length - 2 * nested_level)
                + "|" * nested_level
            )

    for layer in model.layers[layer_range[0] : layer_range[1]]:
        print_layer(layer)
    print_fn("=" * line_length)

    if hasattr(model, "_collected_trainable_weights"):
        trainable_count = count_params(model._collected_trainable_weights)
        trainable_memory_size = weight_memory_size(
            model._collected_trainable_weights
        )
    else:
        trainable_count = count_params(model.trainable_weights)
        trainable_memory_size = weight_memory_size(model.trainable_weights)

    non_trainable_count = count_params(model.non_trainable_weights)
    non_trainable_memory_size = weight_memory_size(model.non_trainable_weights)

    total_memory_size = trainable_memory_size + non_trainable_memory_size

    print_fn(
        f"Total params: {trainable_count + non_trainable_count} "
        f"({readable_memory_size(total_memory_size)})"
    )
    print_fn(
        f"Trainable params: {trainable_count} "
        f"({readable_memory_size(trainable_memory_size)})"
    )
    print_fn(
        f"Non-trainable params: {non_trainable_count} "
        f"({readable_memory_size(non_trainable_memory_size)})"
    )
    print_fn("_" * line_length)

    print_dtensor_variable_summary(model, print_fn, line_length)
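

# Illustrative sketch (hypothetical model): pass a custom `print_fn` to
# capture the summary lines instead of printing them to stdout.
def _example_print_summary():
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(2, input_shape=(4,))]
    )
    lines = []
    print_summary(model, print_fn=lines.append)
    return lines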


def convert_dense_weights_data_format(
    dense, previous_feature_map_shape, target_data_format="channels_first"
):
    """Utility useful when changing a convnet's `data_format`.

    When porting the weights of a convnet from one data format to the other,
    if the convnet includes a `Flatten` layer
    (applied to the last convolutional feature map)
    followed by a `Dense` layer, the weights of that `Dense` layer
    should be updated to reflect the new dimension ordering.

    Args:
        dense: The target `Dense` layer.
        previous_feature_map_shape: A shape tuple of 3 integers,
            e.g. `(512, 7, 7)`. The shape of the convolutional
            feature map right before the `Flatten` layer that
            came before the target `Dense` layer.
        target_data_format: One of "channels_last", "channels_first".
            Set it to "channels_last"
            if converting a "channels_first" model to "channels_last",
            or reciprocally.
    """
    assert target_data_format in {"channels_last", "channels_first"}
    kernel, bias = dense.get_weights()
    for i in range(kernel.shape[1]):
        if target_data_format == "channels_first":
            c, h, w = previous_feature_map_shape
            original_fm_shape = (h, w, c)
            ki = kernel[:, i].reshape(original_fm_shape)
            ki = np.transpose(ki, (2, 0, 1))  # last -> first
        else:
            h, w, c = previous_feature_map_shape
            original_fm_shape = (c, h, w)
            ki = kernel[:, i].reshape(original_fm_shape)
            ki = np.transpose(ki, (1, 2, 0))  # first -> last
        kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
    dense.set_weights([kernel, bias])
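

# Illustrative sketch (hypothetical shapes): a Dense layer that followed a
# Flatten over a (7, 7, 512) channels_last feature map gets its kernel rows
# permuted in place to match the (512, 7, 7) channels_first ordering.
def _example_convert_dense_weights():
    dense = tf.keras.layers.Dense(10)
    dense.build((None, 512 * 7 * 7))
    convert_dense_weights_data_format(dense, (512, 7, 7), "channels_first")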


def is_builtin_layer(layer):
    if not getattr(layer, "_keras_api_names", None):
        return False

    # Subclasses of `Layer` that are not exported inherit the export name
    # of the base layer class.
    return layer._keras_api_names != (
        "keras.layers.Layer",
    ) and layer._keras_api_names_v1 != ("keras.layers.Layer",)
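

# Illustrative sketch: exported built-in layers carry their own export names,
# while a plain `Layer` subclass inherits the base class export name and so
# reports False. (Relies on TF export internals; behavior is an assumption.)
def _example_is_builtin_layer():
    class Custom(tf.keras.layers.Layer):
        pass

    return is_builtin_layer(tf.keras.layers.Dense(1)), is_builtin_layer(
        Custom()
    )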


def cached_per_instance(f):
    """Lightweight decorator for caching lazily constructed properties.

    When to use:
    This decorator provides simple caching with minimal overhead. It is
    designed for properties which are expensive to compute and static over
    the life of a class instance, and provides no mechanism for cache
    invalidation. Thus it is best suited for lazily exposing derived
    properties of other static data.

    For classes with custom getattr / setattr behavior (such as trackable
    objects), storing cache results as object attributes is not performant.
    Instead, a specialized cache can significantly reduce property lookup
    overhead. (While still allowing the decorated property to be lazily
    computed.) Consider the following class:

    ```
    class MyClass:
        def __setattr__(self, key, value):
            # Some expensive class specific code
            # ...
            # ...

            super(MyClass, self).__setattr__(key, value)

        @property
        def thing(self):
            # `thing` is expensive to compute (and may not even be
            # requested), so we want to lazily compute it and then cache it.
            output = getattr(self, '_thing', None)
            if output is None:
                self._thing = output = compute_thing(self)
            return output
    ```

    It's also worth noting that ANY overriding of __setattr__, even something
    as simple as:

    ```
    def __setattr__(self, key, value):
        super(MyClass, self).__setattr__(key, value)
    ```

    slows down attribute assignment by nearly 10x.

    By contrast, replacing the definition of `thing` with the following
    sidesteps the expensive __setattr__ altogether:

    ```
    @property
    @cached_per_instance
    def thing(self):
        # `thing` is expensive to compute (and may not even be requested),
        # so we want to lazily compute it and then cache it.
        return compute_thing(self)
    ```

    Performance:
    The overhead for this decorator is ~0.4 us / call. A much lower overhead
    implementation (~0.085 us / call) can be achieved by using a custom dict
    type:

    ```
    def dict_based_cache(f):
        class Cache(dict):
            __slots__ = ()

            def __missing__(self, key):
                self[key] = output = f(key)
                return output

        return property(Cache().__getitem__)
    ```

    However, that implementation holds class instances as keys, and as a
    result blocks garbage collection. (And modifying it to use weakrefs as
    keys raises the lookup overhead to ~0.4 us.) As a result, the
    WeakKeyDictionary implementation below turns out to be more prudent.

    Args:
        f: The function to cache.

    Returns:
        f decorated with simple caching behavior.
    """

    cache = weakref.WeakKeyDictionary()

    @functools.wraps(f)
    def wrapped(item):
        output = cache.get(item)
        if output is None:
            cache[item] = output = f(item)
        return output

    wrapped.cache = cache
    return wrapped
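

# Illustrative sketch: repeated property access is served from the
# WeakKeyDictionary cache, so the body runs once per instance.
def _example_cached_per_instance():
    class Holder:
        @property
        @cached_per_instance
        def thing(self):
            return object()  # Stand-in for an expensive computation.

    h = Holder()
    return h.thing is h.thing  # True: the second access hits the cache.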


def filter_empty_layer_containers(layer_list):
    """Filter out empty Layer-like containers and uniquify."""
    # TODO(b/130381733): Make this an attribute in base_layer.Layer.
    existing = set()
    to_visit = layer_list[::-1]
    while to_visit:
        obj = to_visit.pop()
        if id(obj) in existing:
            continue
        existing.add(id(obj))
        if hasattr(obj, "_is_layer") and not isinstance(obj, type):
            yield obj
        else:
            sub_layers = getattr(obj, "layers", None) or []

            # Trackable data structures will not show up in ".layers" lists,
            # but the layers they contain will.
            to_visit.extend(sub_layers[::-1])
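

# Illustrative sketch: duplicates are dropped by object id, so a layer that
# appears twice in the input is yielded only once.
def _example_filter_empty_layer_containers():
    d = tf.keras.layers.Dense(1)
    return list(filter_empty_layer_containers([d, d]))  # [d]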


class CallFunctionSpec:
    """Caches the spec and provides utilities for handling call function
    args."""

    def __init__(self, full_argspec):
        """Initializes a `CallFunctionSpec`.

        Args:
            full_argspec: the FullArgSpec of a call function of a layer.
        """
        self._full_argspec = full_argspec

        self._arg_names = list(self._full_argspec.args)
        # Scrub `self` that appears if a decorator was applied.
        if self._arg_names and self._arg_names[0] == "self":
            self._arg_names = self._arg_names[1:]
        self._arg_names += self._full_argspec.kwonlyargs or []

        call_accepts_kwargs = self._full_argspec.varkw is not None
        self._expects_training_arg = (
            "training" in self._arg_names or call_accepts_kwargs
        )
        self._expects_mask_arg = (
            "mask" in self._arg_names or call_accepts_kwargs
        )

        call_fn_defaults = self._full_argspec.defaults or []
        defaults = dict()
        # The call arg defaults are an n-tuple of the last n elements of the
        # args list. (n = # of elements that have a default argument)
        for i in range(-1 * len(call_fn_defaults), 0):
            defaults[self._arg_names[i]] = call_fn_defaults[i]
        # The default training arg will be any (non-None) default specified
        # in the method signature, or None if no value is specified.
        defaults.update(self._full_argspec.kwonlydefaults or {})
        self._default_training_arg = defaults.get("training")

    @property
    def full_argspec(self):
        """Returns the FullArgSpec of the call function."""
        return self._full_argspec

    @property
    def arg_names(self):
        """List of names of args and kwonlyargs."""
        # `arg_names` is not accurate if the layer has variable positional
        # args.
        return self._arg_names

    @arg_names.setter
    def arg_names(self, value):
        self._arg_names = value

    @property
    @cached_per_instance
    def arg_positions(self):
        """Returns a dict mapping arg names to their index positions."""
        # `arg_positions` is not accurate if the layer has variable
        # positional args.
        call_fn_arg_positions = dict()
        for pos, arg in enumerate(self._arg_names):
            call_fn_arg_positions[arg] = pos
        return call_fn_arg_positions

    @property
    def expects_training_arg(self):
        """Whether the call function uses `training` as a parameter."""
        return self._expects_training_arg

    @expects_training_arg.setter
    def expects_training_arg(self, value):
        self._expects_training_arg = value

    @property
    def expects_mask_arg(self):
        """Whether the call function uses `mask` as a parameter."""
        return self._expects_mask_arg

    @expects_mask_arg.setter
    def expects_mask_arg(self, value):
        self._expects_mask_arg = value

    @property
    def default_training_arg(self):
        """The default value given to the `training` argument."""
        return self._default_training_arg

    def arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
        """Returns true if argument is present in `args` or `kwargs`.

        Args:
            arg_name: String name of the argument to find.
            args: Tuple of args passed to the call function.
            kwargs: Dictionary of kwargs passed to the call function.
            inputs_in_args: Whether the input argument (the first argument
                in the call function) is included in `args`. Defaults to
                `False`.

        Returns:
            True if argument with `arg_name` is present in `args` or
            `kwargs`.
        """
        # Performance optimization: do no work in most common case.
        if not args and not kwargs:
            return False

        if arg_name in kwargs:
            return True
        call_fn_args = self._arg_names
        if not inputs_in_args:
            # Ignore `inputs` arg.
            call_fn_args = call_fn_args[1:]
        return arg_name in dict(zip(call_fn_args, args))

    def get_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
        """Retrieves the value for the argument with name `arg_name`.

        Args:
            arg_name: String name of the argument to find.
            args: Tuple of args passed to the call function.
            kwargs: Dictionary of kwargs passed to the call function.
            inputs_in_args: Whether the input argument (the first argument
                in the call function) is included in `args`. Defaults to
                `False`.

        Returns:
            The value of the argument with name `arg_name`, extracted from
            `args` or `kwargs`.

        Raises:
            KeyError if the value of `arg_name` cannot be found.
        """
        if arg_name in kwargs:
            return kwargs[arg_name]
        call_fn_args = self._arg_names
        if not inputs_in_args:
            # Ignore `inputs` arg.
            call_fn_args = call_fn_args[1:]
        args_dict = dict(zip(call_fn_args, args))
        return args_dict[arg_name]

    def set_arg_value(
        self,
        arg_name,
        new_value,
        args,
        kwargs,
        inputs_in_args=False,
        pop_kwarg_if_none=False,
    ):
        """Sets the value of an argument into the given args/kwargs.

        Args:
            arg_name: String name of the argument to find.
            new_value: New value to give to the argument.
            args: Tuple of args passed to the call function.
            kwargs: Dictionary of kwargs passed to the call function.
            inputs_in_args: Whether the input argument (the first argument
                in the call function) is included in `args`. Defaults to
                `False`.
            pop_kwarg_if_none: If the new value is `None`, and this is
                `True`, then the argument is deleted from `kwargs`.

        Returns:
            The updated `(args, kwargs)`.
        """
        if self.full_argspec.varargs:
            try:
                arg_pos = self.full_argspec.args.index(arg_name)
                if self.full_argspec.args[0] == "self":
                    arg_pos -= 1
            except ValueError:
                arg_pos = None
        else:
            arg_pos = self.arg_positions.get(arg_name, None)

        if arg_pos is not None:
            if not inputs_in_args:
                # Ignore `inputs` arg.
                arg_pos = arg_pos - 1
            if len(args) > arg_pos:
                args = list(args)
                args[arg_pos] = new_value
                return tuple(args), kwargs
        if new_value is None and pop_kwarg_if_none:
            kwargs.pop(arg_name, None)
        else:
            kwargs[arg_name] = new_value
        return args, kwargs

    def split_out_first_arg(self, args, kwargs):
        """Splits (args, kwargs) into (inputs, args, kwargs)."""
        # Grab the argument corresponding to the first argument in the
        # layer's `call` method spec. This will either be the first
        # positional argument, or it will be provided as a keyword argument.
        if args:
            inputs = args[0]
            args = args[1:]
        elif self._arg_names[0] in kwargs:
            kwargs = copy.copy(kwargs)
            inputs = kwargs.pop(self._arg_names[0])
        else:
            raise ValueError(
                "The first argument to `Layer.call` must always be passed."
            )
        return inputs, args, kwargs
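

# Illustrative sketch (hypothetical `call` signature): the spec is built
# from `inspect.getfullargspec` of a layer's `call` method.
def _example_call_function_spec():
    import inspect

    def call(self, inputs, training=None, mask=None):
        return inputs

    spec = CallFunctionSpec(inspect.getfullargspec(call))
    # `self` is scrubbed, so arg_names == ["inputs", "training", "mask"].
    assert spec.expects_training_arg and spec.expects_mask_arg
    return spec.arg_was_passed("training", (), {"training": True})  # True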


@keras_export("keras.utils.warmstart_embedding_matrix")
def warmstart_embedding_matrix(
    base_vocabulary,
    new_vocabulary,
    base_embeddings,
    new_embeddings_initializer="uniform",
):
    """Warm start embedding matrix with changing vocab.

    This util can be used to warmstart the embedding layer matrix when the
    vocabulary changes between a previously saved checkpoint and the model.
    A vocabulary change could mean that the size of the new vocab is
    different, that the vocabulary is reshuffled, or that new vocabulary has
    been added to the old vocabulary. If the vocabulary size changes, the
    size of the embedding layer matrix also changes. This util remaps the
    old vocabulary embeddings to the new embedding layer matrix.

    Example:
    Here is an example that demonstrates how to use the
    `warmstart_embedding_matrix` util.
    >>> import keras
    >>> vocab_base = tf.convert_to_tensor(["unk", "a", "b", "c"])
    >>> vocab_new = tf.convert_to_tensor(
    ...     ["unk", "unk", "a", "b", "c", "d", "e"])
    >>> vectorized_vocab_base = np.random.rand(vocab_base.shape[0], 3)
    >>> vectorized_vocab_new = np.random.rand(vocab_new.shape[0], 3)
    >>> warmstarted_embedding_matrix = warmstart_embedding_matrix(
    ...     base_vocabulary=vocab_base,
    ...     new_vocabulary=vocab_new,
    ...     base_embeddings=vectorized_vocab_base,
    ...     new_embeddings_initializer=keras.initializers.Constant(
    ...         vectorized_vocab_new))

    Here is an example that demonstrates how to get vocabulary and embedding
    weights from layers, use the `warmstart_embedding_matrix` util to remap
    the layer embeddings, and continue with model training.
    ```
    # Get old and new vocabulary via layer.get_vocabulary(), assuming, for
    # example, that a TextVectorization layer is used.
    base_vocabulary = old_text_vectorization_layer.get_vocabulary()
    new_vocabulary = new_text_vectorization_layer.get_vocabulary()
    # Get previous embedding layer weights.
    embedding_weights_base = model.get_layer('embedding').get_weights()[0]
    warmstarted_embedding = keras.utils.warmstart_embedding_matrix(
        base_vocabulary,
        new_vocabulary,
        base_embeddings=embedding_weights_base,
        new_embeddings_initializer="uniform")
    updated_embedding_variable = tf.Variable(warmstarted_embedding)

    # Update the embedding layer weights and continue with model training.
    model.layers[1].embeddings = updated_embedding_variable
    model.fit(...)
    ```

    Args:
        base_vocabulary: The list of vocabulary terms that
            the preexisting embedding matrix `base_embeddings` represents.
            It can be either a 1D array/tensor or a tuple/list of vocabulary
            terms (strings), or a path to a vocabulary text file. If passing
            a file path, the file should contain one line per term in the
            vocabulary.
        new_vocabulary: The list of vocabulary terms for the new vocabulary
            (same format as above).
        base_embeddings: NumPy array or tensor representing the preexisting
            embedding matrix.
        new_embeddings_initializer: Initializer for the embedding vectors of
            previously unseen terms to be added to the new embedding matrix
            (see `keras.initializers`). To fully specify the new embedding
            matrix, pass a "constant" initializer as in the example above.
            Defaults to "uniform".

    Returns:
        tf.tensor of the remapped embedding layer matrix.
    """
    # Convert vocab to list.
    base_vocabulary = convert_vocab_to_list(base_vocabulary)
    new_vocabulary = convert_vocab_to_list(new_vocabulary)

    # Initialize the new embedding layer matrix.
    new_embeddings_initializer = initializers.get(new_embeddings_initializer)
    new_embedding = new_embeddings_initializer(
        shape=(len(new_vocabulary), base_embeddings.shape[1]),
        dtype=base_embeddings.dtype,
    )

    # Create mapping dict {vocab: index}.
    base_vocabulary_dict = dict(
        zip(base_vocabulary, range(len(base_vocabulary)))
    )

    indices_base_vocabulary = []
    indices_new_vocabulary = []
    for index, key in enumerate(new_vocabulary):
        if key in base_vocabulary_dict:
            indices_base_vocabulary.append(base_vocabulary_dict[key])
            indices_new_vocabulary.append(int(index))

    # Update the embedding matrix.
    if indices_base_vocabulary:
        values_to_update = tf.gather(base_embeddings, indices_base_vocabulary)
        new_embedding = tf.tensor_scatter_nd_update(
            new_embedding,
            tf.expand_dims(indices_new_vocabulary, axis=1),
            values_to_update,
        )
    return new_embedding


def convert_vocab_to_list(vocab):
    """Convert input vocabulary to a list."""
    vocab_list = []
    if tf.is_tensor(vocab):
        vocab_list = list(vocab.numpy())
    elif isinstance(vocab, (np.ndarray, tuple, list)):
        vocab_list = list(vocab)
    elif isinstance(vocab, str):
        if not tf.io.gfile.exists(vocab):
            raise ValueError(f"Vocabulary file {vocab} does not exist.")
        with tf.io.gfile.GFile(vocab, "r") as vocabulary_file:
            vocab_list = vocabulary_file.read().splitlines()
    else:
        raise ValueError(
            "Vocabulary is expected to be either a NumPy array, "
            "list, 1D tensor or a vocabulary text file. Instead type "
            f"{type(vocab)} was received."
        )
    if len(vocab_list) == 0:
        raise ValueError(
            "Vocabulary is expected to be either a NumPy array, "
            "list, 1D tensor or a vocabulary text file with at least one "
            "token. Received 0 instead."
        )
    return vocab_list
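

# Illustrative sketch (hypothetical inputs): tensors, arrays, tuples, and
# lists all normalize to a plain Python list; a string is treated as a path
# to a vocabulary file.
def _example_convert_vocab_to_list():
    assert convert_vocab_to_list(["unk", "a", "b"]) == ["unk", "a", "b"]
    assert convert_vocab_to_list(("a", "b")) == ["a", "b"]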

1112