Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/layer_utils.py: 11%

186 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=protected-access 

16"""Utilities related to layer/model functionality.""" 

17 

18import functools 

19import weakref 

20 

21import numpy as np 

22 

23from tensorflow.python.util import nest 

24from tensorflow.python.util.tf_export import keras_export 

25 

26 

@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Collects the input tensors that `tensor` ultimately depends on.

  The graph is walked backwards through the inbound nodes of the originating
  layer until nodes flagged as inputs are reached. Each source tensor is
  reported once, compared by identity.

  Note: a `tensor` that carries no `_keras_history` attribute is handed back
  unchanged (it is not wrapped in a list).

  Args:
    tensor: The tensor to start from.
    layer: Origin layer of the tensor. Read from `tensor._keras_history` when
      omitted.
    node_index: Origin node index of the tensor. A truthy value also causes
      `layer` and `node_index` to be re-read from `tensor._keras_history`.

  Returns:
    List of input tensors (or `tensor` itself when it has no Keras history).
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history

  inbound_nodes = layer._inbound_nodes
  if not inbound_nodes:
    # The layer was never connected to anything: the tensor is its own source.
    return [tensor]

  node = inbound_nodes[node_index]
  if node.is_input:
    # Reached an Input layer, stop recursion.
    return nest.flatten(node.input_tensors)

  collected = []
  for src_layer, src_node_index, _, src_tensor in node.iterate_inbound():
    upstream = get_source_inputs(src_tensor, src_layer, src_node_index)
    # Keep only tensors not already recorded (identity, not equality).
    for candidate in upstream:
      if not any(candidate is kept for kept in collected):
        collected.append(candidate)
  return collected

64 

65 

def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Checks that a string-valued layer argument holds a permitted value.

  Args:
    input_data: The value supplied by the caller.
    allowable_strings: Container of accepted string values.
    layer_name: Name of the layer, used only in the error message.
    arg_name: Name of the argument, used only in the error message.
    allow_none: Whether `None` is an acceptable value.
    allow_callables: Whether any callable is an acceptable value.

  Raises:
    ValueError: If `input_data` is none of the accepted forms.
  """
  if allow_none and input_data is None:
    return
  if allow_callables and callable(input_data):
    return
  if isinstance(input_data, str) and input_data in allowable_strings:
    return

  # Build the human-readable list of what would have been accepted.
  pieces = []
  if allow_none:
    pieces.append('`None`, ')
  if allow_callables:
    pieces.append('a `Callable`, ')
  pieces.append('or one of the following values: %s' % (allowable_strings,))
  message = ('The %s argument of layer %s received an invalid '
             'value %s. Allowed values are: %s.')
  raise ValueError(
      message % (arg_name, layer_name, input_data, ''.join(pieces)))

86 

87 

def count_params(weights):
  """Counts the scalars held by a collection of weights.

  Weights are de-duplicated by object identity before counting, and any
  unknown (`None`) dimension is counted as size 0.

  Args:
    weights: An iterable of weights, each exposing `shape.as_list()`.

  Returns:
    The total number of scalars composing the weights, as a Python int.
  """
  # Later duplicates overwrite earlier ones under the same id, so every
  # distinct object is counted exactly once.
  distinct = {}
  for weight in weights:
    distinct[id(weight)] = weight

  shapes = [weight.shape.as_list() for weight in distinct.values()]

  total = 0
  for shape in shapes:
    known_dims = [0 if dim is None else dim for dim in shape]
    total += np.prod(known_dims)
  return int(total)

103 

104 

def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a tabular summary of a model.

  One row is printed per layer, followed by the parameter totals. Models that
  look like a plain stack of layers get three columns; other graph networks
  additionally get a 'Connected to' column listing inbound connections.

  Args:
    model: Keras model instance.
    line_length: Total length of printed lines (e.g. set this to adapt the
      display to different terminal window sizes). Defaults to 65 for
      sequential-like models and 98 otherwise.
    positions: Relative or absolute positions of log elements in each line.
      When the last entry is <= 1 the values are treated as fractions of
      `line_length`. Defaults to `[.45, .85, 1.]` for sequential-like models
      and `[.33, .55, .67, 1.]` otherwise.
    print_fn: Print function to use. It is called once per summary line, so a
      custom function can be passed to capture the string summary. Defaults
      to `print` (prints to stdout).
  """
  emit = print if print_fn is None else print_fn

  # `Sequential` models and subclassed (non graph-network) models are always
  # logged as a simple sequence of layers. Graph networks need inspection.
  sequential_like = True
  if model.__class__.__name__ != 'Sequential' and model._is_graph_network:
    chain_nodes = []
    for depth_nodes in model._nodes_by_depth.values():
      # Several nodes at one depth, or a single node fed by several inputs,
      # means the model is no longer sequential.
      if (len(depth_nodes) > 1) or (
          len(depth_nodes) == 1 and
          len(nest.flatten(depth_nodes[0].keras_inputs)) > 1):
        sequential_like = False
        break
      chain_nodes.extend(depth_nodes)
    if sequential_like:
      # A layer owning more than one node of the chain is a shared layer.
      for candidate in model.layers:
        already_seen = False
        for node in candidate._inbound_nodes:
          if node not in chain_nodes:
            continue
          if already_seen:
            sequential_like = False
            break
          already_seen = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = [
        node
        for depth_nodes in model._nodes_by_depth.values()
        for node in depth_nodes
    ]

  # Relative positions are converted to absolute column offsets.
  if positions[-1] <= 1:
    positions = [int(line_length * fraction) for fraction in positions]

  def print_row(fields, positions):
    """Emits one table row, cutting/padding each field at its column end."""
    row = ''
    for idx, field in enumerate(fields):
      if idx:
        # Guarantee a separating space between adjacent columns.
        row = row[:-1] + ' '
      row = (row + str(field))[:positions[idx]]
      row = row.ljust(positions[idx])
    emit(row)

  emit('Model: "{}"'.format(model.name))
  emit('_' * line_length)
  print_row(to_display, positions)
  emit('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Args:
      layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    label = layer.name + ' (' + layer.__class__.__name__ + ')'
    if layer.built or getattr(layer, '_is_graph_network', False):
      params = layer.count_params()
    else:
      # If a subclassed model has a layer that is not called in Model.call,
      # the layer will not be built and we cannot call layer.count_params().
      params = '0 (unused)'
    print_row([label, output_shape, params], positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Args:
      layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'

    # Nodes that are not part of the current network are skipped.
    connections = [
        '{}[{}][{}]'.format(inbound_layer.name, node_index, tensor_index)
        for node in layer._inbound_nodes
        if not relevant_nodes or node in relevant_nodes
        for inbound_layer, node_index, tensor_index, _ in
        node.iterate_inbound()
    ]

    label = layer.name + ' (' + layer.__class__.__name__ + ')'
    first_connection = connections[0] if connections else ''
    print_row(
        [label, output_shape, layer.count_params(), first_connection],
        positions)
    # Remaining connections go on continuation rows of their own.
    for extra_connection in connections[1:]:
      print_row(['', '', '', extra_connection], positions)

  if sequential_like:
    summarize = print_layer_summary
  else:
    summarize = print_layer_summary_with_connections

  layers = model.layers
  last_index = len(layers) - 1
  for idx, layer in enumerate(layers):
    summarize(layer)
    # The final layer is closed with '=' instead of '_'.
    emit(('=' if idx == last_index else '_') * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_weights = model._collected_trainable_weights
  else:
    trainable_weights = model.trainable_weights
  trainable_count = count_params(trainable_weights)
  non_trainable_count = count_params(model.non_trainable_weights)

  emit('Total params: {:,}'.format(trainable_count + non_trainable_count))
  emit('Trainable params: {:,}'.format(trainable_count))
  emit('Non-trainable params: {:,}'.format(non_trainable_count))
  emit('_' * line_length)

271 

272 

def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Reorders a `Dense` kernel after a convnet's `data_format` was switched.

  When a convnet with a `Flatten` layer (applied to the last convolutional
  feature map) followed by a `Dense` layer is ported from one data format to
  the other, the rows of that `Dense` kernel must be permuted to follow the
  new dimension ordering. The weights are read with `dense.get_weights()`,
  permuted column by column, and written back with `dense.set_weights()`.

  Args:
    dense: The target `Dense` layer.
    previous_feature_map_shape: A shape tuple of 3 integers, e.g.
      `(512, 7, 7)`. The shape of the convolutional feature map right before
      the `Flatten` layer that came before the target `Dense` layer.
    target_data_format: One of "channels_last", "channels_first". Set it to
      "channels_last" if converting a "channels_first" model to
      "channels_last", or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  to_channels_first = target_data_format == 'channels_first'
  kernel, bias = dense.get_weights()
  for unit in range(kernel.shape[1]):
    if to_channels_first:
      channels, rows, cols = previous_feature_map_shape
      # Column is currently laid out channels-last; move channels to front.
      source_shape, axes = (rows, cols, channels), (2, 0, 1)
    else:
      rows, cols, channels = previous_feature_map_shape
      # Column is currently laid out channels-first; move channels to back.
      source_shape, axes = (channels, rows, cols), (1, 2, 0)
    reordered = np.transpose(kernel[:, unit].reshape(source_shape), axes)
    kernel[:, unit] = np.reshape(reordered,
                                 (np.prod(previous_feature_map_shape),))
  dense.set_weights([kernel, bias])

310 

311 

def is_builtin_layer(layer):
  """Tells whether `layer` carries its own Keras export names.

  Args:
    layer: Object to inspect; its `_keras_api_names` and
      `_keras_api_names_v1` attributes are consulted.

  Returns:
    False when `_keras_api_names` is missing or empty, or when either set of
    export names is exactly that of the base `Layer` class; True otherwise.
  """
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  base_layer_names = ('keras.layers.Layer',)
  if layer._keras_api_names == base_layer_names:
    return False
  return layer._keras_api_names_v1 != base_layer_names

320 

321 

def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily computed.)
  Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:

  ```
  def __setattr__(self, key, value):
    super(MyClass, self).__setattr__(key, value)
  ```

  slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakref's as keys raises
  the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  The argument passed to the decorated function is used as a key of a
  `weakref.WeakKeyDictionary`, so it must be hashable and weak-referenceable.
  A computed result of `None` is cached like any other value, so `f` runs at
  most once per live instance.

  Args:
    f: The function to cache. It is called with the instance as its only
      argument.

  Returns:
    f decorated with simple caching behavior. The returned function exposes
    the underlying cache as its `cache` attribute.
  """

  cache = weakref.WeakKeyDictionary()
  # Private marker for "no entry yet". Testing against `None` instead would
  # never cache a legitimately computed `None` and would re-run `f` each call.
  missing = object()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item, missing)
    if output is missing:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped

415 

416 

def filter_empty_layer_containers(layer_list):
  """Yields each distinct layer reachable from `layer_list`, in order.

  Entries that have an `_is_layer` attribute (and are not classes) are
  yielded. Anything else is treated as a container: its `layers` attribute,
  if present and non-empty, is expanded in place. Objects are de-duplicated
  by identity, so containers without layers contribute nothing.

  Args:
    layer_list: Sliceable sequence of layers and/or layer containers.

  Yields:
    The unique layer objects, in first-encounter order.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  seen_ids = set()
  # Reversed so that popping from the end visits entries in original order.
  pending = layer_list[::-1]
  while pending:
    candidate = pending.pop()
    candidate_id = id(candidate)
    if candidate_id in seen_ids:
      continue
    seen_ids.add(candidate_id)

    if hasattr(candidate, '_is_layer') and not isinstance(candidate, type):
      yield candidate
      continue

    # Trackable data structures will not show up in ".layers" lists, but
    # the layers they contain will.
    children = getattr(candidate, 'layers', None) or []
    pending.extend(children[::-1])