Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/vis_utils.py: 13%
188 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
17"""Utilities related to model visualization."""
19import os
20import sys
22import tensorflow.compat.v2 as tf
24from keras.src.utils import io_utils
25from keras.src.utils import layer_utils
27# isort: off
28from tensorflow.python.util.tf_export import keras_export
30try:
31 # pydot-ng is a fork of pydot that is better maintained.
32 import pydot_ng as pydot
33except ImportError:
34 # pydotplus is an improved version of pydot
35 try:
36 import pydotplus as pydot
37 except ImportError:
38 # Fall back on pydot if necessary.
39 try:
40 import pydot
41 except ImportError:
42 pydot = None
def check_pydot():
    """Report whether a pydot-compatible module could be imported.

    Returns:
        `True` when one of `pydot_ng`, `pydotplus` or `pydot` was importable
        when this module was loaded, `False` otherwise.
    """
    available = pydot is not None
    return available
def check_graphviz():
    """Report whether both pydot and the Graphviz binaries are usable.

    Rendering an empty graph exercises the Graphviz executables; a failure
    to launch them (`OSError` or `pydot.InvocationException`) is treated as
    "not available".

    Returns:
        `True` if pydot is importable and an empty graph renders,
        `False` otherwise.
    """
    if check_pydot():
        blank_graph = pydot.Dot()
        try:
            # Render the blank graph purely to probe the Graphviz install.
            blank_graph.create()
        except (OSError, pydot.InvocationException):
            return False
        return True
    return False
def is_wrapped_model(layer):
    """Return whether `layer` is a `Wrapper` around a functional model.

    Args:
        layer: Any Keras layer.

    Returns:
        `True` if `layer` is a `keras.layers.Wrapper` whose inner `layer`
        attribute is a `Functional` model, else `False`.
    """
    from keras.src.engine import functional
    from keras.src.layers import Wrapper

    if not isinstance(layer, Wrapper):
        return False
    return isinstance(layer.layer, functional.Functional)
def add_edge(dot, src, dst):
    """Add a `src -> dst` edge to `dot` unless such an edge already exists.

    Args:
        dot: A `pydot.Dot` (or `pydot.Cluster`) graph.
        src: Name of the source node.
        dst: Name of the destination node.
    """
    already_present = dot.get_edge(src, dst)
    if already_present:
        return
    dot.add_edge(pydot.Edge(src, dst))
@keras_export("keras.utils.model_to_dot")
def model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
):
    """Convert a Keras model to dot format.

    Args:
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_dtype: whether to display layer dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot:
            'TB' creates a vertical plot;
            'LR' creates a horizontal plot.
        expand_nested: whether to expand nested models into clusters.
        dpi: Dots per inch.
        subgraph: whether to return a `pydot.Cluster` instance.
        layer_range: input of `list` containing two `str` items, which is the
            starting layer name and ending layer name (both inclusive)
            indicating the range of layers for which the `pydot.Dot` will be
            generated. It also accepts regex patterns instead of exact name.
            In such case, start predicate will be the first element it
            matches to `layer_range[0]` and the end predicate will be the
            last element it matches to `layer_range[1]`. By default `None`
            which considers all layers of model. Note that you must pass
            range such that the resultant subgraph must be complete.
        show_layer_activations: Display layer activations (only for layers
            that have an `activation` property).
        show_trainable: whether to display if a layer is trainable. Displays
            'T' when the layer is trainable and 'NT' when it is not trainable.

    Returns:
        A `pydot.Dot` instance representing the Keras model or
        a `pydot.Cluster` instance representing nested model if
        `subgraph=True`.

    Raises:
        ValueError: if `model_to_dot` is called before the model is built, or
            if `layer_range` is malformed.
        ImportError: if pydot is not available.
    """

    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    from keras.src.engine import functional
    from keras.src.engine import sequential
    from keras.src.layers import Wrapper

    if not check_pydot():
        raise ImportError(
            "You must install pydot (`pip install pydot`) for "
            "model_to_dot to work."
        )

    if subgraph:
        dot = pydot.Cluster(style="dashed", graph_name=model.name)
        dot.set("label", model.name)
        dot.set("labeljust", "l")
    else:
        dot = pydot.Dot()
        dot.set("rankdir", rankdir)
        dot.set("concentrate", True)
        dot.set("dpi", dpi)
        dot.set_node_defaults(shape="record")

    if layer_range is not None:
        if len(layer_range) != 2:
            raise ValueError(
                "layer_range must be of shape (2,). Received: "
                f"layer_range = {layer_range} of length {len(layer_range)}"
            )
        if not isinstance(layer_range[0], str) or not isinstance(
            layer_range[1], str
        ):
            raise ValueError(
                "layer_range should contain string type only. "
                f"Received: {layer_range}"
            )
        # Resolve the layer names/patterns to a [start, end) index pair
        # over `model.layers`.
        layer_range = layer_utils.get_layer_index_bound_by_layer_name(
            model, layer_range
        )
        if layer_range[0] < 0 or layer_range[1] > len(model.layers):
            raise ValueError(
                "Both values in layer_range should be in range (0, "
                f"{len(model.layers)}). Received: {layer_range}"
            )

    # First/last pydot nodes of expanded nested models, keyed by model name.
    # "sub_n" = nested model used directly; "sub_w" = model inside a Wrapper.
    sub_n_first_node = {}
    sub_n_last_node = {}
    sub_w_first_node = {}
    sub_w_last_node = {}

    layers = model.layers
    if not model._is_graph_network:
        # Subclassed models have no inspectable graph: draw a single node.
        node = pydot.Node(str(id(model)), label=model.name)
        dot.add_node(node)
        return dot
    elif isinstance(model, sequential.Sequential):
        # `Sequential.layers` hides the auto-generated `InputLayer`; use the
        # full underlying list so the input node is part of the graph.
        layers = super(sequential.Sequential, model).layers
        if layer_range is not None:
            # `layer_range` was resolved against `model.layers`; shift it so
            # that it selects the same layers in the longer list used here.
            offset = len(layers) - len(model.layers)
            layer_range = [layer_range[0] + offset, layer_range[1] + offset]

    def format_dtype(dtype):
        if dtype is None:
            return "?"
        else:
            return str(dtype)

    def format_shape(shape):
        # Braces are record-label syntax in Graphviz, so escape them.
        return (
            str(shape)
            .replace(str(None), "None")
            .replace("{", r"\{")
            .replace("}", r"\}")
        )

    # Create graph nodes.
    for i, layer in enumerate(layers):
        if (layer_range) and (i < layer_range[0] or i >= layer_range[1]):
            continue

        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name
        class_name = layer.__class__.__name__

        if isinstance(layer, Wrapper):
            if expand_nested and isinstance(layer.layer, functional.Functional):
                submodel_wrapper = model_to_dot(
                    layer.layer,
                    show_shapes=show_shapes,
                    show_dtype=show_dtype,
                    show_layer_names=show_layer_names,
                    rankdir=rankdir,
                    expand_nested=expand_nested,
                    subgraph=True,
                    show_layer_activations=show_layer_activations,
                    show_trainable=show_trainable,
                )
                # sub_w : submodel_wrapper
                sub_w_nodes = submodel_wrapper.get_nodes()
                sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
                sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
                dot.add_subgraph(submodel_wrapper)
            else:
                layer_name = f"{layer_name}({layer.layer.name})"
                child_class_name = layer.layer.__class__.__name__
                class_name = f"{class_name}({child_class_name})"

        if expand_nested and isinstance(layer, functional.Functional):
            submodel_not_wrapper = model_to_dot(
                layer,
                show_shapes=show_shapes,
                show_dtype=show_dtype,
                show_layer_names=show_layer_names,
                rankdir=rankdir,
                expand_nested=expand_nested,
                subgraph=True,
                show_layer_activations=show_layer_activations,
                show_trainable=show_trainable,
            )
            # sub_n : submodel_not_wrapper
            sub_n_nodes = submodel_not_wrapper.get_nodes()
            sub_n_first_node[layer.name] = sub_n_nodes[0]
            sub_n_last_node[layer.name] = sub_n_nodes[-1]
            dot.add_subgraph(submodel_not_wrapper)

        # Create node's label.
        label = class_name

        # Rebuild the label as a table including the layer's activation.
        if (
            show_layer_activations
            and hasattr(layer, "activation")
            and layer.activation is not None
        ):
            if hasattr(layer.activation, "name"):
                activation_name = layer.activation.name
            elif hasattr(layer.activation, "__name__"):
                activation_name = layer.activation.__name__
            else:
                activation_name = str(layer.activation)
            label = "{%s|%s}" % (label, activation_name)

        # Rebuild the label as a table including the layer's name.
        if show_layer_names:
            label = f"{layer_name}|{label}"

        # Rebuild the label as a table including the layer's dtype.
        if show_dtype:
            label = f"{label}|{format_dtype(layer.dtype)}"

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = format_shape(layer.output_shape)
            except AttributeError:
                outputlabels = "?"
            if hasattr(layer, "input_shape"):
                inputlabels = format_shape(layer.input_shape)
            elif hasattr(layer, "input_shapes"):
                inputlabels = ", ".join(
                    [format_shape(ishape) for ishape in layer.input_shapes]
                )
            else:
                inputlabels = "?"
            label = "{%s}|{input:|output:}|{{%s}|{%s}}" % (
                label,
                inputlabels,
                outputlabels,
            )

        # Rebuild the label as a table including trainable status
        if show_trainable:
            label = f"{'T' if layer.trainable else 'NT'}|{label}"

        # Expanded nested models are drawn as clusters, not as a single node.
        if not expand_nested or not isinstance(layer, functional.Functional):
            node = pydot.Node(layer_id, label=label)
            dot.add_node(node)

    # Connect nodes with edges. The first layer in the range is skipped
    # (`<=`) because its inbound layers lie outside the plotted range.
    for i, layer in enumerate(layers):
        if (layer_range) and (i <= layer_range[0] or i >= layer_range[1]):
            continue
        layer_id = str(id(layer))
        for node_index, node in enumerate(layer._inbound_nodes):
            node_key = layer.name + "_ib-" + str(node_index)
            # Only follow inbound nodes that belong to this model's graph.
            if node_key in model._network_nodes:
                for inbound_layer in tf.nest.flatten(node.inbound_layers):
                    inbound_layer_id = str(id(inbound_layer))
                    if not expand_nested:
                        assert dot.get_node(inbound_layer_id)
                        assert dot.get_node(layer_id)
                        add_edge(dot, inbound_layer_id, layer_id)
                    else:
                        # if inbound_layer is not Model or wrapped Model
                        if not isinstance(
                            inbound_layer, functional.Functional
                        ) and not is_wrapped_model(inbound_layer):
                            # if current layer is not Model or wrapped Model
                            if not isinstance(
                                layer, functional.Functional
                            ) and not is_wrapped_model(layer):
                                assert dot.get_node(inbound_layer_id)
                                assert dot.get_node(layer_id)
                                add_edge(dot, inbound_layer_id, layer_id)
                            # if current layer is Model
                            elif isinstance(layer, functional.Functional):
                                add_edge(
                                    dot,
                                    inbound_layer_id,
                                    sub_n_first_node[layer.name].get_name(),
                                )
                            # if current layer is wrapped Model
                            elif is_wrapped_model(layer):
                                add_edge(dot, inbound_layer_id, layer_id)
                                name = sub_w_first_node[
                                    layer.layer.name
                                ].get_name()
                                add_edge(dot, layer_id, name)
                        # if inbound_layer is Model
                        elif isinstance(inbound_layer, functional.Functional):
                            name = sub_n_last_node[
                                inbound_layer.name
                            ].get_name()
                            if isinstance(layer, functional.Functional):
                                output_name = sub_n_first_node[
                                    layer.name
                                ].get_name()
                                add_edge(dot, name, output_name)
                            else:
                                add_edge(dot, name, layer_id)
                        # if inbound_layer is wrapped Model
                        elif is_wrapped_model(inbound_layer):
                            inbound_layer_name = inbound_layer.layer.name
                            add_edge(
                                dot,
                                sub_w_last_node[inbound_layer_name].get_name(),
                                layer_id,
                            )
    return dot
@keras_export("keras.utils.plot_model")
def plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
):
    """Converts a Keras model to dot format and saves it to a file.

    Example:

    ```python
    inputs = tf.keras.Input(shape=(100,), dtype='int32', name='input')
    x = tf.keras.layers.Embedding(
        output_dim=512, input_dim=10000, input_length=100)(inputs)
    x = tf.keras.layers.LSTM(32)(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
    model = tf.keras.Model(inputs=[inputs], outputs=[output])
    dot_img_file = '/tmp/model_1.png'
    tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
    ```

    Args:
        model: A Keras model instance
        to_file: File name of the plot image. Its extension selects the
            output format; without an extension the image is written as PNG.
        show_shapes: whether to display shape information.
        show_dtype: whether to display layer dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot: 'TB' creates a
            vertical plot; 'LR' creates a horizontal plot.
        expand_nested: Whether to expand nested models into clusters.
        dpi: Dots per inch.
        layer_range: input of `list` containing two `str` items, which is the
            starting layer name and ending layer name (both inclusive)
            indicating the range of layers for which the plot will be
            generated. It also accepts regex patterns instead of exact name.
            In such case, start predicate will be the first element it
            matches to `layer_range[0]` and the end predicate will be the
            last element it matches to `layer_range[1]`. By default `None`
            which considers all layers of model. Note that you must pass
            range such that the resultant subgraph must be complete.
        show_layer_activations: Display layer activations (only for layers
            that have an `activation` property).
        show_trainable: whether to display if a layer is trainable. Displays
            'T' when the layer is trainable and 'NT' when it is not trainable.

    Raises:
        ImportError: if graphviz or pydot are not available (outside of an
            IPython session; inside one, a message is printed instead and
            `None` is returned).
        ValueError: if `plot_model` is called before the model is built.

    Returns:
        A Jupyter notebook Image object if IPython is installed and the
        output format can be embedded in-line (e.g. PNG, JPEG, GIF).
        This enables in-line display of the model plots in notebooks.
        Otherwise `None`.
    """

    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    if not check_graphviz():
        message = (
            "You must install pydot (`pip install pydot`) "
            "and install graphviz "
            "(see instructions at https://graphviz.gitlab.io/download/) "
            "for plot_model to work."
        )
        if "IPython.core.magics.namespace" in sys.modules:
            # We don't raise an exception here in order to avoid crashing
            # notebook tests where graphviz is not available.
            io_utils.print_msg(message)
            return
        else:
            raise ImportError(message)

    dot = model_to_dot(
        model,
        show_shapes=show_shapes,
        show_dtype=show_dtype,
        show_layer_names=show_layer_names,
        rankdir=rankdir,
        expand_nested=expand_nested,
        dpi=dpi,
        layer_range=layer_range,
        show_layer_activations=show_layer_activations,
        show_trainable=show_trainable,
    )
    to_file = io_utils.path_to_string(to_file)
    if dot is None:
        return
    _, extension = os.path.splitext(to_file)
    if not extension:
        extension = "png"
    else:
        extension = extension[1:]
    # Save image to disk.
    dot.write(to_file, format=extension)
    # Return the image as a Jupyter Image object, to be displayed in-line.
    # Note that we cannot easily detect whether the code is running in a
    # notebook, and thus we always return the Image if Jupyter is available.
    if extension != "pdf":
        try:
            from IPython import display

            return display.Image(filename=to_file)
        except ImportError:
            pass
        except ValueError:
            # IPython's `Image` refuses formats it cannot embed (e.g. "svg",
            # "dot", or an extension-less file name). The plot has already
            # been written to disk, so only the in-line preview is skipped.
            pass