Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/graph_io.py: 44%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Utility functions for reading/writing graphs.""" 

17import os 

18import os.path 

19import sys 

20 

21from google.protobuf import text_format 

22from tensorflow.python.framework import byte_swap_tensor 

23from tensorflow.python.framework import ops 

24from tensorflow.python.lib.io import file_io 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

@tf_export('io.write_graph', v1=['io.write_graph', 'train.write_graph'])
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
  """Writes a graph proto to a file.

  The graph is written as a text proto unless `as_text` is `False`.

  ```python
  v = tf.Variable(0, name='my_variable')
  sess = tf.compat.v1.Session()
  tf.io.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
  ```

  or

  ```python
  v = tf.Variable(0, name='my_variable')
  sess = tf.compat.v1.Session()
  tf.io.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt')
  ```

  On big-endian hosts the tensor contents are byte-swapped to little-endian
  before writing. The swap is applied to a private copy, so a proto passed in
  by the caller is not modified.

  Args:
    graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
    logdir: Directory where to write the graph. This can refer to remote
      filesystems, such as Google Cloud Storage (GCS). May be a `str` or an
      `os.PathLike`.
    name: Filename for the graph.
    as_text: If `True`, writes the graph as an ASCII proto.

  Returns:
    The path of the output proto file.
  """
  if isinstance(graph_or_graph_def, ops.Graph):
    # `as_graph_def()` produces a new proto owned by this function, so it may
    # be mutated freely below.
    graph_def = graph_or_graph_def.as_graph_def()
    owns_graph_def = True
  else:
    graph_def = graph_or_graph_def
    owns_graph_def = False

  if sys.byteorder == 'big':
    # The swap helpers' return values are discarded, i.e. they rewrite the
    # proto in place. Never do that to the caller's object: it would be left
    # in little-endian form and a second call would swap it back.
    if not owns_graph_def:
      graph_def_copy = type(graph_def)()
      graph_def_copy.CopyFrom(graph_def)
      graph_def = graph_def_copy
    # Presumably a GraphDef exposes `node` while a function-level proto does
    # not; each shape has its own swap helper.
    if hasattr(graph_def, 'node'):
      byte_swap_tensor.swap_tensor_content_in_graph_node(
          graph_def, 'big', 'little'
      )
    else:
      byte_swap_tensor.swap_tensor_content_in_graph_function(
          graph_def, 'big', 'little'
      )

  # Accept path-like objects (e.g. pathlib.Path) in addition to str.
  logdir = os.fspath(logdir)

  # gcs does not have the concept of directory at the moment.
  if not logdir.startswith('gs:'):
    file_io.recursive_create_dir(logdir)
  path = os.path.join(logdir, name)
  if as_text:
    # float_format='' formats floats with Python's default str() conversion.
    file_io.atomic_write_string_to_file(
        path, text_format.MessageToString(graph_def, float_format=''))
  else:
    # Deterministic serialization keeps output stable across runs.
    file_io.atomic_write_string_to_file(
        path, graph_def.SerializeToString(deterministic=True))
  return path