Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/vis_utils.py: 13%

188 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16 

17"""Utilities related to model visualization.""" 

18 

19import os 

20import sys 

21 

22import tensorflow.compat.v2 as tf 

23 

24from keras.src.utils import io_utils 

25from keras.src.utils import layer_utils 

26 

27# isort: off 

28from tensorflow.python.util.tf_export import keras_export 

29 

30try: 

31 # pydot-ng is a fork of pydot that is better maintained. 

32 import pydot_ng as pydot 

33except ImportError: 

34 # pydotplus is an improved version of pydot 

35 try: 

36 import pydotplus as pydot 

37 except ImportError: 

38 # Fall back on pydot if necessary. 

39 try: 

40 import pydot 

41 except ImportError: 

42 pydot = None 

43 

44 

def check_pydot():
    """Reports whether a pydot-compatible module was successfully imported.

    Returns:
        `True` if the module-level `pydot` name is bound to a module,
        `False` if it was left as `None`.
    """
    pydot_missing = pydot is None
    return not pydot_missing

48 

49 

def check_graphviz():
    """Returns True if both PyDot and Graphviz are available.

    The check renders an empty graph through pydot; a failure of that call
    is taken to mean the pydot/graphviz installation is not usable.

    Returns:
        `True` if pydot is importable and rendering an empty graph succeeds,
        `False` otherwise.
    """
    if not check_pydot():
        return False

    # The pydot variants this module may import do not necessarily export the
    # same exception names. Naming a missing attribute directly in an `except`
    # clause would raise AttributeError while the real failure is being
    # handled, so collect whichever classes exist before rendering.
    # NOTE(review): `PydotException` is presumably the name used by newer
    # pydot releases -- confirm against the pydot changelog.
    failure_types = [OSError]
    for attr_name in ("InvocationException", "PydotException"):
        candidate = getattr(pydot, attr_name, None)
        if isinstance(candidate, type) and issubclass(candidate, BaseException):
            failure_types.append(candidate)

    try:
        # Attempt to create an image of a blank graph
        # to check the pydot/graphviz installation.
        pydot.Dot.create(pydot.Dot())
    except tuple(failure_types):
        return False
    return True

61 

62 

def is_wrapped_model(layer):
    """Tells whether `layer` is a `Wrapper` around a functional model.

    Args:
        layer: the object to inspect.

    Returns:
        `True` if `layer` is a `Wrapper` instance and its `layer` attribute
        is a `functional.Functional` instance, `False` otherwise.
    """
    # Imported lazily inside the function, as elsewhere in this module.
    from keras.src.engine import functional
    from keras.src.layers import Wrapper

    if not isinstance(layer, Wrapper):
        return False
    wrapped = layer.layer
    return isinstance(wrapped, functional.Functional)

70 

71 

def add_edge(dot, src, dst):
    """Adds an edge from `src` to `dst` unless `dot` already has one.

    Args:
        dot: the pydot graph (or cluster) to modify in place.
        src: name of the source node.
        dst: name of the destination node.
    """
    already_connected = dot.get_edge(src, dst)
    if already_connected:
        return
    new_edge = pydot.Edge(src, dst)
    dot.add_edge(new_edge)

75 

76 

def _format_shape(shape):
    """Returns `str(shape)` with `{` and `}` backslash-escaped."""
    return str(shape).replace("{", r"\{").replace("}", r"\}")


def _layer_node_label(
    layer,
    layer_name,
    class_name,
    show_shapes,
    show_dtype,
    show_layer_names,
    show_layer_activations,
    show_trainable,
):
    """Builds the label text placed on the graph node of a single layer.

    The label starts as `class_name` and is extended, in this order, with
    the activation, the layer name, the dtype, the input/output shapes and
    the trainable flag, each only when the matching `show_*` flag is set.
    """
    label = class_name

    # Rebuild the label as a table including the layer's activation.
    if show_layer_activations:
        activation = getattr(layer, "activation", None)
        if activation is not None:
            if hasattr(activation, "name"):
                activation_name = activation.name
            elif hasattr(activation, "__name__"):
                activation_name = activation.__name__
            else:
                activation_name = str(activation)
            label = "{%s|%s}" % (label, activation_name)

    # Rebuild the label as a table including the layer's name.
    if show_layer_names:
        label = f"{layer_name}|{label}"

    # Rebuild the label as a table including the layer's dtype.
    if show_dtype:
        dtype = layer.dtype
        dtype_text = "?" if dtype is None else str(dtype)
        label = f"{label}|{dtype_text}"

    # Rebuild the label as a table including input/output shapes.
    if show_shapes:
        try:
            output_text = _format_shape(layer.output_shape)
        except AttributeError:
            output_text = "?"
        if hasattr(layer, "input_shape"):
            input_text = _format_shape(layer.input_shape)
        elif hasattr(layer, "input_shapes"):
            input_text = ", ".join(
                _format_shape(ishape) for ishape in layer.input_shapes
            )
        else:
            input_text = "?"
        label = "{%s}|{input:|output:}|{{%s}|{%s}}" % (
            label,
            input_text,
            output_text,
        )

    # Rebuild the label as a table including trainable status.
    if show_trainable:
        label = f"{'T' if layer.trainable else 'NT'}|{label}"

    return label


@keras_export("keras.utils.model_to_dot")
def model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
):
    """Convert a Keras model to dot format.

    Args:
      model: A Keras model instance.
      show_shapes: whether to display shape information.
      show_dtype: whether to display layer dtypes.
      show_layer_names: whether to display layer names.
      rankdir: `rankdir` argument passed to PyDot,
          a string specifying the format of the plot:
          'TB' creates a vertical plot;
          'LR' creates a horizontal plot.
      expand_nested: whether to expand nested models into clusters.
      dpi: Dots per inch.
      subgraph: whether to return a `pydot.Cluster` instance.
      layer_range: input of `list` containing two `str` items, which is the
          starting layer name and ending layer name (both inclusive)
          indicating the range of layers for which the `pydot.Dot` will be
          generated. It also accepts regex patterns instead of exact name.
          In such case, start predicate will be the first element it matches
          to `layer_range[0]` and the end predicate will be the last element
          it matches to `layer_range[1]`. By default `None` which considers
          all layers of model. Note that you must pass range such that the
          resultant subgraph must be complete.
      show_layer_activations: Display layer activations (only for layers that
          have an `activation` property).
      show_trainable: whether to display if a layer is trainable. Displays
          'T' when the layer is trainable and 'NT' when it is not trainable.

    Returns:
      A `pydot.Dot` instance representing the Keras model or
      a `pydot.Cluster` instance representing nested model if
      `subgraph=True`.

    Raises:
      ValueError: if `model_to_dot` is called before the model is built, or
          if `layer_range` does not hold exactly two strings or resolves to
          out-of-bounds layer indices.
      ImportError: if pydot is not available.
    """

    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    from keras.src.engine import functional
    from keras.src.engine import sequential
    from keras.src.layers import Wrapper

    if not check_pydot():
        raise ImportError(
            "You must install pydot (`pip install pydot`) for "
            "model_to_dot to work."
        )

    if subgraph:
        dot = pydot.Cluster(style="dashed", graph_name=model.name)
        dot.set("label", model.name)
        dot.set("labeljust", "l")
    else:
        dot = pydot.Dot()
        dot.set("rankdir", rankdir)
        dot.set("concentrate", True)
        dot.set("dpi", dpi)
        dot.set_node_defaults(shape="record")

    if layer_range is not None:
        if len(layer_range) != 2:
            raise ValueError(
                "layer_range must be of shape (2,). Received: "
                f"layer_range = {layer_range} of length {len(layer_range)}"
            )
        if not isinstance(layer_range[0], str) or not isinstance(
            layer_range[1], str
        ):
            raise ValueError(
                "layer_range should contain string type only. "
                f"Received: {layer_range}"
            )
        # From here on `layer_range` holds whatever the helper returns; the
        # code below uses it as a pair of indices into the layer list.
        layer_range = layer_utils.get_layer_index_bound_by_layer_name(
            model, layer_range
        )
        if layer_range[0] < 0 or layer_range[1] > len(model.layers):
            raise ValueError(
                "Both values in layer_range should be in range (0, "
                f"{len(model.layers)}). Received: {layer_range}"
            )

    # First/last nodes of each expanded sub-model, keyed by the sub-model's
    # name. `sub_n_*` is for nested models used directly as layers, `sub_w_*`
    # for nested models held inside a `Wrapper`.
    sub_n_first_node = {}
    sub_n_last_node = {}
    sub_w_first_node = {}
    sub_w_last_node = {}

    layers = model.layers
    if not model._is_graph_network:
        # Not a graph network: draw the whole model as a single node.
        node = pydot.Node(str(id(model)), label=model.name)
        dot.add_node(node)
        return dot
    elif isinstance(model, sequential.Sequential):
        # Use the parent class's `layers` rather than Sequential's own.
        # The model is known to be built here (checked at the top).
        layers = super(sequential.Sequential, model).layers

    # Keyword arguments forwarded unchanged when recursing into nested
    # models; `dpi` and `layer_range` are deliberately left at their
    # defaults, as the nested graph is created as a cluster.
    nested_kwargs = dict(
        show_shapes=show_shapes,
        show_dtype=show_dtype,
        show_layer_names=show_layer_names,
        rankdir=rankdir,
        expand_nested=expand_nested,
        subgraph=True,
        show_layer_activations=show_layer_activations,
        show_trainable=show_trainable,
    )

    # Create graph nodes.
    for i, layer in enumerate(layers):
        if layer_range and (i < layer_range[0] or i >= layer_range[1]):
            continue

        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name
        class_name = layer.__class__.__name__

        if isinstance(layer, Wrapper):
            if expand_nested and isinstance(layer.layer, functional.Functional):
                submodel_wrapper = model_to_dot(layer.layer, **nested_kwargs)
                # sub_w : submodel_wrapper
                sub_w_nodes = submodel_wrapper.get_nodes()
                sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
                sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
                dot.add_subgraph(submodel_wrapper)
            else:
                layer_name = f"{layer_name}({layer.layer.name})"
                child_class_name = layer.layer.__class__.__name__
                class_name = f"{class_name}({child_class_name})"

        if expand_nested and isinstance(layer, functional.Functional):
            submodel_not_wrapper = model_to_dot(layer, **nested_kwargs)
            # sub_n : submodel_not_wrapper
            sub_n_nodes = submodel_not_wrapper.get_nodes()
            sub_n_first_node[layer.name] = sub_n_nodes[0]
            sub_n_last_node[layer.name] = sub_n_nodes[-1]
            dot.add_subgraph(submodel_not_wrapper)

        # Create node's label.
        label = _layer_node_label(
            layer,
            layer_name,
            class_name,
            show_shapes=show_shapes,
            show_dtype=show_dtype,
            show_layer_names=show_layer_names,
            show_layer_activations=show_layer_activations,
            show_trainable=show_trainable,
        )

        # An expanded nested model is represented by its cluster only and
        # gets no node of its own.
        if not expand_nested or not isinstance(layer, functional.Functional):
            node = pydot.Node(layer_id, label=label)
            dot.add_node(node)

    # Connect nodes with edges.
    for i, layer in enumerate(layers):
        # Unlike the node loop, the layer at index `layer_range[0]` is
        # skipped here as well (`<=`), so no edges are drawn into it.
        if layer_range and (i <= layer_range[0] or i >= layer_range[1]):
            continue
        layer_id = str(id(layer))
        for node_index, node in enumerate(layer._inbound_nodes):
            node_key = layer.name + "_ib-" + str(node_index)
            if node_key not in model._network_nodes:
                continue
            for inbound_layer in tf.nest.flatten(node.inbound_layers):
                inbound_layer_id = str(id(inbound_layer))
                if not expand_nested:
                    assert dot.get_node(inbound_layer_id)
                    assert dot.get_node(layer_id)
                    add_edge(dot, inbound_layer_id, layer_id)
                # if inbound_layer is not Model or wrapped Model
                elif not isinstance(
                    inbound_layer, functional.Functional
                ) and not is_wrapped_model(inbound_layer):
                    # if current layer is not Model or wrapped Model
                    if not isinstance(
                        layer, functional.Functional
                    ) and not is_wrapped_model(layer):
                        assert dot.get_node(inbound_layer_id)
                        assert dot.get_node(layer_id)
                        add_edge(dot, inbound_layer_id, layer_id)
                    # if current layer is Model
                    elif isinstance(layer, functional.Functional):
                        add_edge(
                            dot,
                            inbound_layer_id,
                            sub_n_first_node[layer.name].get_name(),
                        )
                    # if current layer is wrapped Model
                    elif is_wrapped_model(layer):
                        add_edge(dot, inbound_layer_id, layer_id)
                        name = sub_w_first_node[layer.layer.name].get_name()
                        add_edge(dot, layer_id, name)
                # if inbound_layer is Model
                elif isinstance(inbound_layer, functional.Functional):
                    name = sub_n_last_node[inbound_layer.name].get_name()
                    if isinstance(layer, functional.Functional):
                        output_name = sub_n_first_node[layer.name].get_name()
                        add_edge(dot, name, output_name)
                    else:
                        add_edge(dot, name, layer_id)
                # if inbound_layer is wrapped Model
                elif is_wrapped_model(inbound_layer):
                    # NOTE(review): unlike the branch above, the target is
                    # always `layer_id`, even when `layer` is itself an
                    # expanded nested model with no node of its own --
                    # confirm whether that case needs the cluster's first
                    # node instead.
                    inbound_layer_name = inbound_layer.layer.name
                    add_edge(
                        dot,
                        sub_w_last_node[inbound_layer_name].get_name(),
                        layer_id,
                    )
    return dot

376 

377 

@keras_export("keras.utils.plot_model")
def plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
):
    """Renders a Keras model as a graph and writes the image to a file.

    Example:

    ```python
    input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
    x = tf.keras.layers.Embedding(
        output_dim=512, input_dim=10000, input_length=100)(input)
    x = tf.keras.layers.LSTM(32)(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
    model = tf.keras.Model(inputs=[input], outputs=[output])
    dot_img_file = '/tmp/model_1.png'
    tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
    ```

    Args:
      model: A Keras model instance
      to_file: File name of the plot image. The output format is taken from
        its extension; `png` is used when there is none.
      show_shapes: whether to display shape information.
      show_dtype: whether to display layer dtypes.
      show_layer_names: whether to display layer names.
      rankdir: `rankdir` argument passed to PyDot,
          a string specifying the format of the plot: 'TB' creates a vertical
          plot; 'LR' creates a horizontal plot.
      expand_nested: Whether to expand nested models into clusters.
      dpi: Dots per inch.
      layer_range: input of `list` containing two `str` items, which is the
        starting layer name and ending layer name (both inclusive) indicating
        the range of layers for which the plot will be generated. It also
        accepts regex patterns instead of exact name. In such case, start
        predicate will be the first element it matches to `layer_range[0]`
        and the end predicate will be the last element it matches to
        `layer_range[1]`. By default `None` which considers all layers of
        model. Note that you must pass range such that the resultant subgraph
        must be complete.
      show_layer_activations: Display layer activations (only for layers that
        have an `activation` property).
      show_trainable: whether to display if a layer is trainable. Displays
        'T' when the layer is trainable and 'NT' when it is not trainable.

    Raises:
      ImportError: if graphviz or pydot are not available.
      ValueError: if `plot_model` is called before the model is built.

    Returns:
      A Jupyter notebook Image object if Jupyter is installed and the output
      format is not `pdf`; otherwise `None`.
      This enables in-line display of the model plots in notebooks.
    """

    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    if not check_graphviz():
        message = (
            "You must install pydot (`pip install pydot`) "
            "and install graphviz "
            "(see instructions at https://graphviz.gitlab.io/download/) "
            "for plot_model to work."
        )
        ipython_loaded = "IPython.core.magics.namespace" in sys.modules
        if not ipython_loaded:
            raise ImportError(message)
        # Only print the message in this case, so that notebook tests do not
        # crash where graphviz is not available.
        io_utils.print_msg(message)
        return

    dot = model_to_dot(
        model,
        show_shapes=show_shapes,
        show_dtype=show_dtype,
        show_layer_names=show_layer_names,
        rankdir=rankdir,
        expand_nested=expand_nested,
        dpi=dpi,
        layer_range=layer_range,
        show_layer_activations=show_layer_activations,
        show_trainable=show_trainable,
    )
    to_file = io_utils.path_to_string(to_file)
    if dot is None:
        return

    # Derive the output format from the file suffix (without the dot).
    suffix = os.path.splitext(to_file)[1]
    extension = suffix[1:] if suffix else "png"

    # Save image to disk.
    dot.write(to_file, format=extension)

    if extension == "pdf":
        return

    # Hand the image back as a Jupyter Image object for in-line display.
    # Whether the code runs inside a notebook cannot easily be detected, so
    # the Image is returned whenever IPython can be imported.
    try:
        from IPython import display

        return display.Image(filename=to_file)
    except ImportError:
        pass

498