Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/graph_io.py: 44%
25 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Utility functions for reading/writing graphs."""
17import os
18import os.path
19import sys
21from google.protobuf import text_format
22from tensorflow.python.framework import byte_swap_tensor
23from tensorflow.python.framework import ops
24from tensorflow.python.lib.io import file_io
25from tensorflow.python.util.tf_export import tf_export
@tf_export('io.write_graph', v1=['io.write_graph', 'train.write_graph'])
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
  """Writes a graph proto to a file.

  The graph is written as a text proto unless `as_text` is `False`.

  ```python
  v = tf.Variable(0, name='my_variable')
  sess = tf.compat.v1.Session()
  tf.io.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
  ```

  or

  ```python
  v = tf.Variable(0, name='my_variable')
  sess = tf.compat.v1.Session()
  tf.io.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt')
  ```

  Args:
    graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
    logdir: Directory where to write the graph, given as a string or an
      `os.PathLike` object (for example `pathlib.Path`). This can refer to
      remote filesystems, such as Google Cloud Storage (GCS).
    name: Filename for the graph.
    as_text: If `True`, writes the graph as an ASCII proto.

  Returns:
    The path of the output proto file.
  """
  if isinstance(graph_or_graph_def, ops.Graph):
    graph_def = graph_or_graph_def.as_graph_def()
  else:
    graph_def = graph_or_graph_def

  # On big-endian hosts, tensor content is swapped from 'big' to 'little'
  # before it is serialized. Protos exposing a `node` field go through the
  # graph-node helper; anything else goes through the graph-function helper.
  # NOTE(review): the helpers' return values are unused, so they presumably
  # modify `graph_def` in place -- confirm whether a caller-supplied GraphDef
  # is left mutated after this call.
  if sys.byteorder == 'big':
    if hasattr(graph_def, 'node'):
      byte_swap_tensor.swap_tensor_content_in_graph_node(
          graph_def, 'big', 'little'
      )
    else:
      byte_swap_tensor.swap_tensor_content_in_graph_function(
          graph_def, 'big', 'little'
      )

  # Normalize path-like objects up front so the string prefix test below
  # works for them too; for a plain `str` this returns the value unchanged.
  logdir = os.fspath(logdir)

  # gcs does not have the concept of directory at the moment.
  if not logdir.startswith('gs:'):
    file_io.recursive_create_dir(logdir)
  path = os.path.join(logdir, name)

  if as_text:
    contents = text_format.MessageToString(graph_def, float_format='')
  else:
    # Deterministic serialization is requested explicitly for the binary form.
    contents = graph_def.SerializeToString(deterministic=True)
  file_io.atomic_write_string_to_file(path, contents)
  return path