Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/grappler/cluster.py: 45%

44 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""A python interface for Grappler clusters.""" 

16 

17import contextlib 

18 

19from tensorflow.core.framework import step_stats_pb2 

20from tensorflow.core.grappler.costs import op_performance_data_pb2 

21from tensorflow.core.protobuf import device_properties_pb2 

22from tensorflow.python.grappler import _pywrap_tf_cluster as tf_cluster 

23 

24 

class Cluster(object):
  """Grappler Clusters.

  Python wrapper around a native cluster handle obtained from
  `_pywrap_tf_cluster`. The handle is released by `Shutdown()`, which is also
  invoked from `__del__`.
  """

  def __init__(self,
               allow_soft_placement=True,
               disable_detailed_stats=True,
               disable_timeline=True,
               devices=None):
    """Creates a Cluster.

    Args:
      allow_soft_placement: If True, TF will automatically fix illegal
        placements instead of erroring out if the placement isn't legal.
        Only forwarded to the native layer when `devices` is None.
      disable_detailed_stats: If True, detailed statistics will not be
        available. Only forwarded to the native layer when `devices` is None.
      disable_timeline: If True, the timeline information will not be reported.
      devices: A list of devices of type device_properties_pb2.NamedDevice.
        If None, a device list will be created based on the spec of
        the local machine.
    """
    # Initialize the handle first so Shutdown()/__del__ see a defined value
    # even if creating the native cluster below raises.
    self._tf_cluster = None
    self._generate_timeline = not disable_timeline

    if devices is None:
      self._tf_cluster = tf_cluster.TF_NewCluster(allow_soft_placement,
                                                  disable_detailed_stats)
    else:
      # Virtual cluster built from serialized NamedDevice protos; note that
      # allow_soft_placement / disable_detailed_stats are not passed here.
      devices_serialized = [device.SerializeToString() for device in devices]
      self._tf_cluster = tf_cluster.TF_NewVirtualCluster(devices_serialized)

  def Shutdown(self):
    """Releases the native cluster handle, if one is held.

    Idempotent: after the first call the handle is None and later calls do
    nothing. Also safe on an instance whose `__init__` never ran, because
    `__del__` calls this method on such instances too.
    """
    # getattr instead of a plain attribute read: if __init__ failed before
    # assigning _tf_cluster (e.g. bad constructor arguments), Python still
    # runs __del__, and a direct read would raise AttributeError there.
    handle = getattr(self, "_tf_cluster", None)
    if handle is not None:
      tf_cluster.TF_ShutdownCluster(handle)
    self._tf_cluster = None

  def __del__(self):
    self.Shutdown()

  @property
  def tf_cluster(self):
    """The underlying native cluster handle, or None after Shutdown()."""
    return self._tf_cluster

  def ListDevices(self):
    """Returns a list of available hardware devices.

    Returns:
      A list of device_properties_pb2.NamedDevice protos, or an empty list if
      the cluster has been shut down.
    """
    if self._tf_cluster is None:
      return []
    return [device_properties_pb2.NamedDevice.FromString(device)
            for device in tf_cluster.TF_ListDevices(self._tf_cluster)]

  def ListAvailableOps(self):
    """Returns a list of all available operations (sorted alphabetically).

    This does not consult the cluster handle.
    """
    return tf_cluster.TF_ListAvailableOps()

  def GetSupportedDevices(self, item):
    """Returns the devices that can run the ops of `item` on this cluster.

    Args:
      item: An object exposing a `tf_item` attribute (presumably a grappler
        Item wrapper -- confirm against callers).

    Returns:
      Whatever `TF_GetSupportedDevices` returns; the exact structure is
      defined by the native wrapper.
    """
    return tf_cluster.TF_GetSupportedDevices(self._tf_cluster, item.tf_item)

  def EstimatePerformance(self, device):
    """Returns the native performance estimate for `device`.

    This does not consult the cluster handle.

    Args:
      device: A protobuf message describing the device (presumably
        device_properties_pb2.DeviceProperties -- confirm); it is passed to
        the native layer in serialized form.
    """
    return tf_cluster.TF_EstimatePerformance(device.SerializeToString())

  def MeasureCosts(self, item):
    """Returns the cost of running the specified item.

    Args:
      item: The item for which to measure the costs.

    Returns:
      A triplet (op_perfs, run_time, step_stats) where op_perfs is a list of
      op_performance_data_pb2.OpPerformance protos, run_time is the value
      reported by the native layer, and step_stats is a
      step_stats_pb2.StepStats proto.
    """
    op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts(
        item.tf_item, self._tf_cluster, self._generate_timeline)

    op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes)
                for op_perf_bytes in op_perf_bytes_list]
    return (op_perfs, run_time,
            step_stats_pb2.StepStats.FromString(step_stats_bytes))

  def DeterminePeakMemoryUsage(self, item):
    """Returns a snapshot of the peak memory usage.

    Args:
      item: The item whose peak memory usage should be determined.

    Returns:
      A hashtable indexed by device name.
    """
    return tf_cluster.TF_DeterminePeakMemoryUsage(item.tf_item,
                                                  self._tf_cluster)

108 

109 

@contextlib.contextmanager
def Provision(allow_soft_placement=True,
              disable_detailed_stats=True,
              disable_timeline=True,
              devices=None):
  """Context manager that creates a `Cluster` and shuts it down on exit.

  Args:
    allow_soft_placement: Forwarded to `Cluster`.
    disable_detailed_stats: Forwarded to `Cluster`.
    disable_timeline: Forwarded to `Cluster`.
    devices: Forwarded to `Cluster`.

  Yields:
    The newly created `Cluster`.
  """
  cluster = Cluster(allow_soft_placement, disable_detailed_stats,
                    disable_timeline, devices)
  try:
    yield cluster
  finally:
    # If the `with` body raises, the exception is thrown into this generator
    # at the `yield`; `finally` guarantees the native cluster is still
    # released instead of lingering until garbage collection.
    cluster.Shutdown()