Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/grappler/cluster.py: 45%
44 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A python interface for Grappler clusters."""
17import contextlib
19from tensorflow.core.framework import step_stats_pb2
20from tensorflow.core.grappler.costs import op_performance_data_pb2
21from tensorflow.core.protobuf import device_properties_pb2
22from tensorflow.python.grappler import _pywrap_tf_cluster as tf_cluster
class Cluster(object):
  """Grappler Clusters.

  Python wrapper around a cluster handle created through the
  `_pywrap_tf_cluster` extension. The handle is released by `Shutdown()`,
  which `__del__` also calls.
  """

  # Class-level default so an instance whose `__init__` never ran (e.g.
  # `Cluster(bad_kwarg=1)` fails while binding arguments, or the object was
  # created via `__new__`) still looks "already shut down". Without it,
  # `__del__` -> `Shutdown()` raises AttributeError on `self._tf_cluster`.
  _tf_cluster = None

  def __init__(self,
               allow_soft_placement=True,
               disable_detailed_stats=True,
               disable_timeline=True,
               devices=None):
    """Creates a Cluster.

    Args:
      allow_soft_placement: If True, TF will automatically fix illegal
        placements instead of erroring out if the placement isn't legal.
        Only forwarded when `devices` is None.
      disable_detailed_stats: If True, detailed statistics will not be
        available. Only forwarded when `devices` is None.
      disable_timeline: If True, the timeline information will not be reported.
      devices: A list of devices of type device_properties_pb2.NamedDevice.
        If None, a device list will be created based on the spec of
        the local machine.
    """
    self._tf_cluster = None
    # Consumed by MeasureCosts, which asks the backend for a timeline.
    self._generate_timeline = not disable_timeline

    if devices is None:
      self._tf_cluster = tf_cluster.TF_NewCluster(allow_soft_placement,
                                                  disable_detailed_stats)
    else:
      # NOTE: the virtual-cluster constructor receives only the serialized
      # devices; allow_soft_placement and disable_detailed_stats are unused.
      devices_serialized = [device.SerializeToString() for device in devices]
      self._tf_cluster = tf_cluster.TF_NewVirtualCluster(devices_serialized)

  def Shutdown(self):
    """Releases the underlying cluster handle.

    Safe to call repeatedly: after the first call the handle is None and
    later calls are no-ops.
    """
    if self._tf_cluster is not None:
      tf_cluster.TF_ShutdownCluster(self._tf_cluster)
      self._tf_cluster = None

  def __del__(self):
    """Shuts the cluster down when the wrapper is garbage-collected."""
    self.Shutdown()

  @property
  def tf_cluster(self):
    """The raw cluster handle, or None once the cluster has been shut down."""
    return self._tf_cluster

  def ListDevices(self):
    """Returns a list of available hardware devices.

    Returns:
      A list of device_properties_pb2.NamedDevice protos, or an empty list if
      the cluster has been shut down.
    """
    if self._tf_cluster is None:
      return []
    return [device_properties_pb2.NamedDevice.FromString(device)
            for device in tf_cluster.TF_ListDevices(self._tf_cluster)]

  def ListAvailableOps(self):
    """Returns a list of all available operations (sorted alphabetically).

    This query does not use this cluster's handle.
    """
    return tf_cluster.TF_ListAvailableOps()

  def GetSupportedDevices(self, item):
    """Queries the devices of this cluster that can run `item`.

    Args:
      item: An object exposing a `tf_item` attribute (presumably a grappler
        Item -- confirm against callers).

    Returns:
      The value returned by TF_GetSupportedDevices (format not visible here).
    """
    return tf_cluster.TF_GetSupportedDevices(self._tf_cluster, item.tf_item)

  def EstimatePerformance(self, device):
    """Estimates the performance of a single device.

    Args:
      device: A protobuf message describing the device; it is serialized with
        SerializeToString() before being handed to the backend.

    Returns:
      The value returned by TF_EstimatePerformance. This query does not use
      this cluster's handle.
    """
    return tf_cluster.TF_EstimatePerformance(device.SerializeToString())

  def MeasureCosts(self, item):
    """Returns the cost of running the specified item.

    Args:
      item: The item for which to measure the costs.
    Returns: The triplet op_perfs, runtime, step_stats.
    """
    op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts(
        item.tf_item, self._tf_cluster, self._generate_timeline)

    # The backend returns serialized protos; decode them for the caller.
    op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes)
                for op_perf_bytes in op_perf_bytes_list]
    return (op_perfs, run_time,
            step_stats_pb2.StepStats.FromString(step_stats_bytes))

  def DeterminePeakMemoryUsage(self, item):
    """Returns a snapshot of the peak memory usage.

    Args:
      item: The item for which to determine the peak memory usage.
    Returns: A hashtable indexed by device name.
    """
    return tf_cluster.TF_DeterminePeakMemoryUsage(item.tf_item,
                                                  self._tf_cluster)
@contextlib.contextmanager
def Provision(allow_soft_placement=True,
              disable_detailed_stats=True,
              disable_timeline=True,
              devices=None):
  """Context manager that creates a Cluster and shuts it down on exit.

  Args:
    allow_soft_placement: Forwarded to the Cluster constructor.
    disable_detailed_stats: Forwarded to the Cluster constructor.
    disable_timeline: Forwarded to the Cluster constructor.
    devices: Forwarded to the Cluster constructor; None means the local
      machine's devices, otherwise a list of NamedDevice protos.

  Yields:
    The newly created Cluster.
  """
  cluster = Cluster(allow_soft_placement, disable_detailed_stats,
                    disable_timeline, devices)
  try:
    yield cluster
  finally:
    # Without `finally`, an exception raised inside the `with` body would
    # skip Shutdown and keep the cluster alive until garbage collection.
    cluster.Shutdown()