1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of ClusterResolvers for GCE instance groups."""
16
17from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
18from tensorflow.python.training.server_lib import ClusterSpec
19from tensorflow.python.util.tf_export import tf_export
20
21
# The Google API client libraries are optional dependencies. Record whether
# they could be imported so that code needing them can check this flag instead
# of failing at module import time.
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False
else:
  _GOOGLE_API_CLIENT_INSTALLED = True
28
29
@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this will retrieve the IP address of all the instances within the instance
  group and return a ClusterResolver object suitable for use for distributed
  TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1-a",
                                          "my-instance-group",
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1-a",
                                          "my-instance-group",
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port=8470,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This takes in a few parameters and creates a GCEClusterResolver object. It
    will then use these parameters to query the GCE API for the IP addresses of
    each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group.
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server (default: 8470).
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. If left as 'default', this resolves to
        GoogleCredentials.get_application_default() when the API service has
        to be built.
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If the googleapiclient is not installed and no `service`
        was supplied.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    if service is not None:
      # A caller-supplied service already carries its own authentication, so
      # the credentials argument is ignored and application-default
      # credentials are deliberately not looked up.
      self._service = service
      return

    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise ImportError('googleapiclient must be installed before using the '
                        'GCE cluster resolver')

    if credentials == 'default':
      self._credentials = GoogleCredentials.get_application_default()

    self._service = discovery.build(
        'compute', 'v1',
        credentials=self._credentials)

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called: one list request per result page, plus
    one request per listed instance to look up its IP address.

    Returns:
      A ClusterSpec containing host information retrieved from GCE. All hosts
      are placed under the single job named by `task_type`, as a list of
      'ip:port' strings sorted as strings. The job's host list is empty when
      the instance group has no running instances.
    """
    request_body = {'instanceState': 'RUNNING'}
    request = self._service.instanceGroups().listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroups=self._instance_group,
        body=request_body,
        orderBy='name')

    worker_list = []

    # `listInstances_next` returns None once the last page has been consumed.
    while request is not None:
      response = request.execute()

      # The API omits the 'items' key when a page has no matching instances
      # (e.g. no instance of the group is RUNNING), so do not index it
      # directly.
      for instance in response.get('items', []):
        # 'instance' holds a resource URL; its last path segment is the name.
        instance_name = instance['instance'].split('/')[-1]

        instance_request = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name)

        if instance_request is not None:
          instance_details = instance_request.execute()
          # Use the internal IP of the first network interface.
          ip_address = instance_details['networkInterfaces'][0]['networkIP']
          worker_list.append('%s:%s' % (ip_address, self._port))

      request = self._service.instanceGroups().listInstances_next(
          previous_request=request,
          previous_response=response)

    # Sort the addresses so that the host ordering does not depend on the
    # order in which the API returned the instances.
    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of the requested task in the instance group.

    The GCE API is queried (via `cluster_spec()`) on each call that resolves
    an address.

    Args:
      task_type: Job name to look up. Falls back to the `task_type` of this
        resolver when None.
      task_id: Task index to look up. Falls back to the `task_id` of this
        resolver when None.
      rpc_layer: RPC layer to prefix the address with. Falls back to the
        `rpc_layer` of this resolver when empty or None.

    Returns:
      The task address, formatted as '<rpc_layer>://<address>' when an RPC
      layer is available and as the bare address otherwise. An empty string is
      returned when either the task type or the task id is still None after
      applying the fallbacks.
    """
    if task_type is None:
      task_type = self._task_type
    if task_id is None:
      task_id = self._task_id

    if task_type is None or task_id is None:
      return ''

    address = self.cluster_spec().task_address(task_type, task_id)
    effective_rpc_layer = rpc_layer or self._rpc_layer
    if effective_rpc_layer:
      return '%s://%s' % (effective_rpc_layer, address)
    return address

  @property
  def task_type(self):
    """The job name given to this resolver at construction time."""
    return self._task_type

  @property
  def task_id(self):
    """The task index currently assigned to this resolver."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    # The task type is fixed at construction time; reassignment is rejected.
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer used as the prefix of addresses returned by `master`."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer