Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/accelerator_util.py: 24%
79 statements
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility for working with accelerator systems."""

from typing import List, Optional

from absl import logging

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.platform import remote_utils
from tensorflow.python.util.tf_export import tf_export

_INITIALIZED_ACCELERATOR_SYSTEM_TYPE = None


def is_initialized() -> bool:
  """Returns whether the accelerator system has been initialized."""
  return bool(_INITIALIZED_ACCELERATOR_SYSTEM_TYPE)


def set_initialized(value):
  """Sets whether the accelerator system has been initialized."""
  global _INITIALIZED_ACCELERATOR_SYSTEM_TYPE
  _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = value


def initialize_multi_client_cluster(job_name: str,
                                    dtensor_jobs: List[str],
                                    client_id: int,
                                    collective_leader: str,
                                    port: Optional[int] = None,
                                    gpu_use_nccl_communication: bool = False,
                                    enable_coordination_service: bool = True):
  """Initializes GRPC servers and collectives for multi-client DTensor setup.

  This function can be used to initialize a multi-client cluster and enable
  collective ops. GRPC servers are necessary in the multi-client mode, even
  when the number of clients is 1.

  NOTE: this function must be called in an eager context.

  Args:
    job_name: The job name used by all clients in the DTensor cluster.
    dtensor_jobs: A list of the DTensor client jobs participating in the
      cluster. Must be strings of the form "hostname:port".
    client_id: The ID of the DTensor client this function is being called in.
    collective_leader: The job/task that will be used to run collectives.
    port: The port this client's GRPC server will run on. If omitted, use the
      port from dtensor_jobs for this client.
    gpu_use_nccl_communication: If True, configure TensorFlow to use NCCL by
      default.
    enable_coordination_service: If True, enable the distributed coordination
      service to make sure that workers know the devices on each other, a
      prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
  """
  assert context.executing_eagerly()

  if not collective_leader.startswith("/job:"):
    collective_leader = "/job:" + collective_leader

  context.context().configure_collective_ops(
      use_nccl_communication=gpu_use_nccl_communication,
      collective_leader=collective_leader)
  if enable_coordination_service:
    context.context().configure_coordination_service(
        service_type="standalone", service_leader=collective_leader)

  config_proto = context.get_config()

  # Construct server def from the host directly instead of relying on
  # TF_CONFIG.
  cluster_def = cluster_pb2.ClusterDef()
  # Note that for bns addresses, we will currently rely on the sorted string
  # of job name as the order of assigning task ids. This might be brittle once
  # we have jobs across multiple cells.
  cluster_def.job.add(name=job_name, tasks=dict(enumerate(dtensor_jobs)))
  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_def,
      default_session_config=config_proto,
      job_name=job_name,
      task_index=client_id,
      protocol=remote_utils.get_default_communication_protocol(),
      port=port)
  server_def.default_session_config.rpc_options.num_channels_per_target = 4
  server_def.default_session_config.experimental.recv_buf_max_chunk = -1

  logging.info("Enabling collectives with server_def: %s", server_def)

  context.context().enable_collective_ops(server_def)

  context.ensure_initialized()
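

# A minimal usage sketch for the function above, kept as comments so nothing
# runs at import time. The job name, hostnames, and ports are illustrative
# assumptions, not values defined in this module:
#
#   initialize_multi_client_cluster(
#       job_name="worker",
#       dtensor_jobs=["host1:10000", "host2:10000"],  # one entry per client
#       client_id=0,                                  # this process's index
#       collective_leader="worker/task:0")            # "/job:" prefix is added
#
# Each participating process runs the same call with its own client_id.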


@tf_export(
    "experimental.dtensor.initialize_accelerator_system",
    "experimental.dtensor.initialize_tpu_system",
    "experimental.dtensor.initialize_multi_client",
    v1=[])
def initialize_accelerator_system(
    device_type: Optional[str] = None,
    enable_coordination_service: Optional[bool] = True,
    experimental_reset_context: Optional[bool] = False,
) -> str:
  """Initializes accelerators and communication fabrics for DTensor.

  DTensor configures TensorFlow to run in the local mode or multi-client mode.
  - In local mode, a mesh can only use devices attached to the current process.
  - In multi-client mode, a mesh can span across devices from multiple clients.

  If `DTENSOR_JOBS` is non-empty, DTensor configures TensorFlow to run in the
  multi-client mode using the distributed runtime. In multi-client mode devices
  on different clients can communicate with each other.

  The following environment variables control the behavior of this function.

  - `DTENSOR_JOBS`: string, a comma separated list. Each item in the list is
    of format `{hostname}:{port}`. If empty, DTensor runs in the local mode.
    Examples of valid `DTENSOR_JOBS` values:
    - 4 clients on localhost:
      `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
    - 2 clients on host1, 2 clients on host2:
      `host1:10000,host1:10001,host2:10000,host2:10003`
    If the hostnames are BNS addresses, the items must be sorted in
    alphabetical order.
  - `DTENSOR_CLIENT_ID`: integer, between `0` and `num_clients - 1`, to
    identify the client id of the current process. The default value is `0`.
  - `DTENSOR_JOB_NAME`: string, a string for the name of the TensorFlow job.
    The job name controls the job name section of the TensorFlow DeviceSpecs,
    e.g., `job:worker` in `/job:worker/replica:0/task:0/device:TPU:0` when
    the job name is `worker`.
    The default value is `localhost` in local mode, and `worker` in the
    multi-client mode. All DTensor clients within the same multi-client
    cluster share the same job name.
  - `DTENSOR_USE_PARALLEL_EXECUTOR`: string, with its value being `pw` to
    specify that the backend is Pathways, and TensorFlow otherwise.

  Args:
    device_type: Type of accelerator to use, can be CPU, GPU, or TPU. If None,
      uses `tf.experimental.dtensor.preferred_device_type()`.
    enable_coordination_service: If true, enable the distributed coordination
      service to make sure that workers know the devices on each other, when
      there is more than 1 client.
    experimental_reset_context: Reset the TensorFlow context. Behaviors of
      existing TensorFlow objects (e.g. Tensors) are undefined. Set this to
      True as an escape hatch, if there is no clear way to refactor your code
      to call initialize_accelerator_system() before calling TensorFlow APIs
      that initialize the context.

  Returns:
    device_type: the type of accelerator that was initialized.
  """
  global _INITIALIZED_ACCELERATOR_SYSTEM_TYPE
  assert context.executing_eagerly()

  if is_initialized():
    raise ValueError(
        "Accelerator system has already been initialized. "
        "Call tf.experimental.dtensor.shutdown_accelerator_system() first.")

  if experimental_reset_context:
    if context.context()._initialized:  # pylint: disable=protected-access
      logging.warn(
          "experimental_reset_context is True. "
          "Resetting TensorFlow context. Existing TensorFlow objects "
          "(e.g. Tensors and resources) are invalidated."
      )
      context.context().ensure_uninitialized()

  if context.context()._initialized:  # pylint: disable=protected-access
    raise ValueError(
        "TensorFlow has already been initialized. "
        "tf.experimental.dtensor.initialize_accelerator_system() must be "
        "called before TensorFlow is initialized.")

  context.context()._clear_caches()  # pylint: disable=protected-access

  if device_type is None:
    device_type = config.preferred_device_type()

  device_type = device_type.upper()
  if device_type not in {"CPU", "GPU", "TPU"}:
    raise ValueError(f"Unknown device_type {device_type}. "
                     "Allowed values are CPU, GPU, or TPU")

  if config.gpu_use_nccl_communication():
    logical_gpu_count = config.num_local_devices("GPU")
    physical_gpu_count = len(tf_config.list_physical_devices("GPU"))
    if logical_gpu_count > physical_gpu_count:
      raise ValueError(
          "DTENSOR_GPU_USE_NCCL_COMMUNICATION is set for using NCCL. "
          "NCCL collectives require a one-to-one mapping between logical and "
          "physical GPUs. "
          f"The number of logical GPUs ({logical_gpu_count}) "
          f"is more than the number of physical GPUs ({physical_gpu_count})."
      )

  # Configure logical host CPU devices for accelerators.
  if device_type in ("GPU", "TPU"):
    num_local_devices = config.num_local_devices(device_type)
    if config.num_local_devices("CPU") < num_local_devices:
      tf_config.set_logical_device_configuration(
          tf_config.list_physical_devices("CPU")[0],
          [context.LogicalDeviceConfiguration()] * num_local_devices)

  if not config.is_local_mode():
    initialize_multi_client_cluster(
        job_name=config.job_name(),
        dtensor_jobs=config.jobs(),
        client_id=config.client_id(),
        collective_leader=config.full_job_name(task_id=0),
        gpu_use_nccl_communication=config.gpu_use_nccl_communication(),
        enable_coordination_service=enable_coordination_service)
  else:
    if device_type == "GPU":
      # Enables NCCL in local mode.
      context.context(  # pylint: disable=protected-access
      )._collective_use_nccl_communication = config.gpu_use_nccl_communication(
      )

  if device_type == "TPU" and not config.backend_is_pw():
    tpu_util.initialize_tpu_system()

  _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = device_type

  return device_type
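

# A minimal usage sketch for the function above (comments only; the values
# below are illustrative assumptions, not defaults of this module):
#
#   # Local mode: leave DTENSOR_JOBS unset/empty and simply call
#   device_type = initialize_accelerator_system()  # picks the preferred type
#   ...                                            # build meshes, run DTensor
#   shutdown_accelerator_system()                  # local mode only
#
# For multi-client runs, set DTENSOR_JOBS, DTENSOR_CLIENT_ID, and
# DTENSOR_JOB_NAME (see the docstring above) before calling this function.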


@tf_export(
    "experimental.dtensor.shutdown_accelerator_system",
    "experimental.dtensor.shutdown_tpu_system",
    v1=[])
def shutdown_accelerator_system() -> None:
  """Shuts down the accelerator system."""
  global _INITIALIZED_ACCELERATOR_SYSTEM_TYPE
  context.async_wait()

  if not is_initialized():
    raise ValueError(
        "Accelerator system is not initialized. Call "
        "tf.experimental.dtensor.initialize_accelerator_system first.")

  device_type = _INITIALIZED_ACCELERATOR_SYSTEM_TYPE

  if not config.is_local_mode():
    raise ValueError(
        "Shutting down accelerator system under multi-client mode is "
        "not supported.")

  if device_type == "TPU" and not config.backend_is_pw():
    tpu_util.shutdown_tpu_system()

  # reset TF context to stop gRPC servers.
  context._reset_context()  # pylint: disable=protected-access
  context.context()._clear_caches()  # pylint: disable=protected-access
  _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = None
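

# Sketch of a guarded init/shutdown pairing using the helpers above
# (illustrative only):
#
#   if not is_initialized():
#     initialize_accelerator_system("TPU")
#   ...
#   if is_initialized():
#     shutdown_accelerator_system()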