Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cluster_resolver/gce_cluster_resolver.py: 35%

74 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implementation of ClusterResolvers for GCE instance groups.""" 

16 

17from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver 

18from tensorflow.python.training.server_lib import ClusterSpec 

19from tensorflow.python.util.tf_export import tf_export 

20 

21 

22_GOOGLE_API_CLIENT_INSTALLED = True 

23try: 

24 from googleapiclient import discovery # pylint: disable=g-import-not-at-top 

25 from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top 

26except ImportError: 

27 _GOOGLE_API_CLIENT_INSTALLED = False 

28 

29 

@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this will retrieve the IP address of all the running instances within the
  instance group and expose them as a `ClusterSpec` suitable for use for
  distributed TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group", 8470,
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group", 8470,
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This takes in a few parameters and creates a GCEClusterResolver object. It
    will then use these parameters to query the GCE API for the IP addresses of
    each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group.
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server. Required; it is appended
        to every instance IP address to form the task address.
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. When left as the string `'default'` and
        the Google API client is installed, this is replaced by
        GoogleCredentials.get_application_default().
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If no `service` is supplied and the googleapiclient is not
        installed.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    # Only resolve application-default credentials when the optional Google
    # client libraries were importable; otherwise keep the caller's value.
    if credentials == 'default' and _GOOGLE_API_CLIENT_INSTALLED:
      self._credentials = GoogleCredentials.get_application_default()

    if service is not None:
      self._service = service
      return

    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise ImportError('googleapiclient must be installed before using the '
                        'GCE cluster resolver')
    self._service = discovery.build(
        'compute', 'v1',
        credentials=self._credentials)

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called.

    Returns:
      A ClusterSpec containing host information retrieved from GCE. The single
      job is named after `task_type` and its addresses are sorted as strings.
      The job is empty when no running instance is reported.
    """
    request_body = {'instanceState': 'RUNNING'}
    # NOTE(review): the Compute Engine REST reference appears to name this path
    # parameter `instanceGroup` (singular) -- confirm the `instanceGroups`
    # keyword against the discovery document before relying on a real service.
    request = self._service.instanceGroups().listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroups=self._instance_group,
        body=request_body,
        orderBy='name')

    worker_list = []

    # `listInstances_next` yields None once there are no more result pages.
    while request is not None:
      response = request.execute()

      # A page may carry no `items` key at all (for example when no instance
      # matches the RUNNING filter); treat that as an empty page rather than
      # failing with KeyError.
      for instance in response.get('items') or []:
        # `instance['instance']` is a resource URL; its last path segment is
        # the instance name expected by `instances().get`.
        instance_name = instance['instance'].split('/')[-1]

        instance_request = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name)

        if instance_request is not None:
          instance_details = instance_request.execute()
          ip_address = instance_details['networkInterfaces'][0]['networkIP']
          worker_list.append('%s:%s' % (ip_address, self._port))

      request = self._service.instanceGroups().listInstances_next(
          previous_request=request,
          previous_response=response)

    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of a task in the resolved cluster.

    Each call invokes `cluster_spec()` and therefore queries the GCE API again.

    Args:
      task_type: Job name to look up. Falls back to the resolver's own
        `task_type` when None.
      task_id: Task index to look up. Falls back to the resolver's own
        `task_id` when None.
      rpc_layer: RPC layer used as the address prefix. Falls back to the
        resolver's own `rpc_layer` when falsy.

    Returns:
      `'<rpc_layer>://<address>'` when an RPC layer is available, the bare
      address when it is not, or an empty string when the task type or task id
      cannot be determined.
    """
    task_type = task_type if task_type is not None else self._task_type
    task_id = task_id if task_id is not None else self._task_id

    if task_type is None or task_id is None:
      return ''

    address = self.cluster_spec().task_address(task_type, task_id)
    layer = rpc_layer or self._rpc_layer
    if layer:
      return '%s://%s' % (layer, address)
    return address

  @property
  def task_type(self):
    """The job name this resolver was created with (read-only)."""
    return self._task_type

  @task_type.setter
  def task_type(self, task_type):
    # The job name keys the ClusterSpec built by `cluster_spec`, so it is
    # fixed at construction time.
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @property
  def task_id(self):
    """The task index used by default in `master`."""
    return self._task_id

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer used by default as the address prefix in `master`."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer