Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/remote.py: 22%

97 statements  


# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to connect to remote servers."""

16 

17import copy 

18 

19from absl import logging 

20 

21from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef 

22from tensorflow.python import pywrap_tfe 

23from tensorflow.python.distribute import device_util 

24from tensorflow.python.distribute.cluster_resolver import cluster_resolver 

25from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver 

26from tensorflow.python.eager import context 

27from tensorflow.python.framework import ops 

28from tensorflow.python.platform import remote_utils 

29from tensorflow.python.training import server_lib 

30from tensorflow.python.util import nest 

31from tensorflow.python.util.tf_export import tf_export 

32 

33 

34_GRPC_PREFIX = "grpc://" 

35_LOCAL_MASTERS = ("", "local") 

36 

37 

@tf_export("config.experimental_connect_to_host")
def connect_to_remote_host(remote_host=None, job_name="worker"):
  """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # When eager execution is enabled, connect to the remote host.
  tf.config.experimental_connect_to_host("exampleaddr.com:9876")

  with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = array_ops.ones([2, 2])
    x2 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x2)
  ```
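
  A list of addresses also works; each address becomes one task of the
  `job_name` job, in order (the hosts below are placeholders):

  ```python
  tf.config.experimental_connect_to_host(
      ["exampleaddr.com:9876", "exampleaddr.com:9877"])  # tasks 0 and 1
  ```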

  Args:
    remote_host: a single remote server address, or a list of addresses, in
      host-port format.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None.
  """
  if not remote_host:
    raise ValueError("Must provide at least one remote_host")

  remote_hosts = nest.flatten(remote_host)
  cluster_spec = server_lib.ClusterSpec(
      {job_name: [_strip_prefix(host, _GRPC_PREFIX) for host in remote_hosts]})

  connect_to_cluster(cluster_spec)


@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
  """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Device filters can be specified to isolate groups of remote tasks, to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be partially specified.

  For example, for a cluster set up for parameter server training, the
  following device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```
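
  A minimal end-to-end sketch (the worker address below is a placeholder, not
  a default from this module):

  ```python
  # Describe a one-worker cluster and connect to it; devices on the worker
  # become visible to the local eager context.
  cluster_spec = tf.train.ClusterSpec({"worker": ["10.0.0.1:2222"]})
  tf.config.experimental_connect_to_cluster(cluster_spec)

  with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
    x = tf.ones([2, 2])  # placed on the remote worker
  ```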

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
      `tf.train.experimental.ClusterDeviceFilters` that specifies device
      filters for the remote tasks in the cluster.
  """

  if not context.executing_eagerly():
    raise ValueError(
        "`tf.config.experimental_connect_to_cluster` can only be called in "
        "eager mode."
    )
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
        "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
        "`ClusterResolver`.")

  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
    else:
      raise ValueError("`cluster_device_filters` must be an instance of "
                       "`tf.train.experimental.ClusterDeviceFilters`.")

  # Check whether the server def has changed. We need to do the check before
  # the local job is added to the cluster.
  is_server_def_changed = False
  current_server_def = context.get_server_def()
  if current_server_def and job_name not in cluster_spec.jobs:
    for i, job in enumerate(current_server_def.cluster.job):
      if job.name == job_name:
        del current_server_def.cluster.job[i]
  if (current_server_def is None or current_server_def.cluster != cluster_def or
      current_server_def.job_name != job_name or
      current_server_def.task_index != task_index):
    is_server_def_changed = True

  # Automatically add the local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    # TODO(fishx): Update this to make sure the remote worker has a valid ip
    # address to connect with local.
    job_def.tasks[0] = "localhost:{}".format(local_port)

  if context.context().coordination_service is None:
    service_type = remote_utils.coordination_service_type(protocol)
    service_leader = ""
    # Maybe enable coordination service for the communication protocol.
    # TODO(b/243839559): Fix UPTC + Coordination service crashing
    if isinstance(cluster_spec_or_resolver,
                  tpu_cluster_resolver.TPUClusterResolver):
      is_uptc_sess = ".uptc-worker." in cluster_spec_or_resolver.master()
      service_type = remote_utils.coordination_service_type(
          protocol, is_uptc_sess)
      service_leader = cluster_spec_or_resolver.get_coordination_service_leader(
      )
    if service_type:
      # If `enable_health_check` were true, the coordination service agent
      # would connect (and tasks would send heartbeats once the connection is
      # set up) while creating eager contexts. Enabling the health check does
      # not mutate the coordination service.
      context.context().configure_coordination_service(
          service_type=service_type,
          service_leader=service_leader,
          enable_health_check=False)

  default_session_config = copy.deepcopy(context.context().config)

  for name in cluster_spec.jobs:
    # Assume any non-local job is a worker job.
    # Should we use cluster_spec_or_resolver.get_job_name() instead when it is
    # available?
    # Maybe consolidate this with the 'master' logic below.
    if name == job_name:
      continue

    default_session_config.experimental.collective_group_leader = (
        f"/job:{name}/replica:0/task:0"
    )

  logging.info("default session config: %s", default_session_config)

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=default_session_config,
      cluster_device_filters=cluster_device_filters,
  )
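
  # A changed server def requires a fresh setup (which, as the docstring notes,
  # invalidates tensor handles on the old remote devices); an unchanged one can
  # be applied as an in-place update.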

  if is_server_def_changed:
    context.set_server_def(server_def)
  else:
    context.update_server_def(server_def)

  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name = None
    master_task_id = None
    for job_name in cluster_spec.jobs:
      for task_id in cluster_spec.task_indices(job_name):
        task_address = cluster_spec.task_address(job_name, task_id)
        if master in task_address or task_address in master:
          master_job_name = job_name
          master_task_id = task_id
          break

    if not master_job_name:
      raise ValueError(
          "`make_master_device_default` is set to True but cannot find "
          "master %s in the cluster" % master)

    master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError("`connect_to_cluster` is called inside existing device "
                       "scope %s, which is different from the master device "
                       "scope %s to enter. This is not allowed." %
                       (current_device, master_device))
    # TODO(b/138389076): Think of the entering device scope behavior in the
    # failure recovery case when dealing with preemptions.
    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      ops.device(master_device).__enter__()


def _strip_prefix(s, prefix):
  """Returns `s` with a leading `prefix` (e.g. "grpc://") removed, if any."""
  return s[len(prefix):] if s.startswith(prefix) else s