Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/remote.py: 22%
97 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to connect to remote servers."""

import copy

from absl import logging

from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import remote_utils
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


_GRPC_PREFIX = "grpc://"
_LOCAL_MASTERS = ("", "local")

@tf_export("config.experimental_connect_to_host")
def connect_to_remote_host(remote_host=None, job_name="worker"):
  """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # When eager execution is enabled, connect to the remote host.
  tf.config.experimental_connect_to_host("exampleaddr.com:9876")

  with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = array_ops.ones([2, 2])
    x2 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x2)
  ```

  Args:
    remote_host: a single remote server address, or a list of addresses, in
      host-port format.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None.
  """
  if not remote_host:
    raise ValueError("Must provide at least one remote_host")

  remote_hosts = nest.flatten(remote_host)
  cluster_spec = server_lib.ClusterSpec(
      {job_name: [_strip_prefix(host, _GRPC_PREFIX) for host in remote_hosts]})

  connect_to_cluster(cluster_spec)
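
# A minimal usage sketch (illustrative, not part of the library): the docstring
# above notes that `remote_host` may also be a list. The addresses below are
# hypothetical; each entry becomes one task of the default "worker" job, in
# list order.
#
#   tf.config.experimental_connect_to_host(
#       ["10.0.0.1:8470", "10.0.0.2:8470"])
#   with tf.device("/job:worker/replica:0/task:1/device:CPU:0"):
#     x = tf.ones([2, 2])  # resident on the second remote host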

@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
  """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Device filters can be specified to isolate groups of remote tasks to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be partially specified.

  For example, for a cluster set up for parameter server training, the
  following device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified,
      will use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
      `tf.train.experimental.ClusterDeviceFilters` that specifies device
      filters for the remote tasks in the cluster.
  """
  if not context.executing_eagerly():
    raise ValueError(
        "`tf.config.experimental_connect_to_cluster` can only be called in "
        "eager mode."
    )
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
        "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
        "`ClusterResolver`.")

  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
    else:
      raise ValueError("`cluster_device_filters` must be an instance of "
                       "`tf.train.experimental.ClusterDeviceFilters`.")
  # Check whether the server def has changed. We need to do the check before
  # the local job is added to the cluster.
  is_server_def_changed = False
  current_server_def = context.get_server_def()
  if current_server_def and job_name not in cluster_spec.jobs:
    for i, job in enumerate(current_server_def.cluster.job):
      if job.name == job_name:
        del current_server_def.cluster.job[i]
  if (current_server_def is None or current_server_def.cluster != cluster_def or
      current_server_def.job_name != job_name or
      current_server_def.task_index != task_index):
    is_server_def_changed = True

  # Automatically add local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    # TODO(fishx): Update this to make sure remote worker has valid ip address
    # to connect with local.
    job_def.tasks[0] = "localhost:{}".format(local_port)
  if context.context().coordination_service is None:
    service_type = remote_utils.coordination_service_type(protocol)
    service_leader = ""
    # Maybe enable coordination service for the communication protocol.
    # TODO(b/243839559): Fix UPTC + Coordination service crashing.
    if isinstance(cluster_spec_or_resolver,
                  tpu_cluster_resolver.TPUClusterResolver):
      is_uptc_sess = ".uptc-worker." in cluster_spec_or_resolver.master()
      service_type = remote_utils.coordination_service_type(
          protocol, is_uptc_sess)
      service_leader = cluster_spec_or_resolver.get_coordination_service_leader(
      )
    if service_type:
      # If `enable_health_check` is true, the coordination service agent would
      # connect (and tasks would send heartbeats once the connection is set up)
      # while creating eager contexts. Enabling the health check does not
      # mutate the coordination service.
      context.context().configure_coordination_service(
          service_type=service_type,
          service_leader=service_leader,
          enable_health_check=False)
  default_session_config = copy.deepcopy(context.context().config)

  for name in cluster_spec.jobs:
    # Assume any non-local job is a worker job.
    # Should we use cluster_spec_or_resolver.get_job_name() instead when it is
    # available?
    # Maybe consolidate this with the 'master' logic below.
    if name == job_name:
      continue

    default_session_config.experimental.collective_group_leader = (
        f"/job:{name}/replica:0/task:0"
    )

  logging.info("default session config: %s", default_session_config)

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=default_session_config,
      cluster_device_filters=cluster_device_filters,
  )
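  # Apply the server def: if the cluster topology, job name, or task index
  # changed, install it from scratch; otherwise update the existing one in
  # place.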
  if is_server_def_changed:
    context.set_server_def(server_def)
  else:
    context.update_server_def(server_def)

  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name = None
    master_task_id = None
    for job_name in cluster_spec.jobs:
      for task_id in cluster_spec.task_indices(job_name):
        task_address = cluster_spec.task_address(job_name, task_id)
        if master in task_address or task_address in master:
          master_job_name = job_name
          master_task_id = task_id
          break

    if not master_job_name:
      raise ValueError(
          "`make_master_device_default` is set to True but cannot find "
          "master %s in the cluster" % master)

    master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError("`connect_to_cluster` is called inside existing device "
                       "scope %s, which is different from the master device "
                       "scope %s to enter. This is not allowed." %
                       (current_device, master_device))
    # TODO(b/138389076): Think of the entering device scope behavior in the
    # failure recovery case when dealing with preemptions.
    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      ops.device(master_device).__enter__()
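
# A minimal usage sketch (illustrative only): connect to a small two-job
# cluster described by a `ClusterSpec`. The addresses are hypothetical; with a
# plain `ClusterSpec` (rather than a `ClusterResolver`),
# `make_master_device_default` has no effect, and the local "localhost" job is
# added automatically because it is not part of the spec.
#
#   cluster_spec = tf.train.ClusterSpec({
#       "worker": ["10.0.0.1:8470"],
#       "ps": ["10.0.0.2:8470"],
#   })
#   tf.config.experimental_connect_to_cluster(cluster_spec)
#   with tf.device("/job:ps/replica:0/task:0"):
#     v = tf.Variable(1.0)  # placed on the parameter-server task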

def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s
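
# For example (hypothetical address):
#   _strip_prefix("grpc://10.0.0.1:8470", _GRPC_PREFIX)  # -> "10.0.0.1:8470"
# An address without the prefix is returned unchanged.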